Compare commits
24 Commits
parth/samp
...
v0.6.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b51e0f397c | ||
|
|
b42970063d | ||
|
|
493385eb3e | ||
|
|
9876c9faa4 | ||
|
|
4e415029b3 | ||
|
|
e172f095ba | ||
|
|
c001b98087 | ||
|
|
23fc8e92eb | ||
|
|
4059a297a6 | ||
|
|
66b2539238 | ||
|
|
ef27d52e79 | ||
|
|
b2a465296d | ||
|
|
5d097277ef | ||
|
|
071a9872cb | ||
|
|
0bd0454ea7 | ||
|
|
01aa788722 | ||
|
|
ead27aa9fe | ||
|
|
b816ff86c9 | ||
|
|
e5d84fb90b | ||
|
|
dd66712e31 | ||
|
|
f66216e399 | ||
|
|
f4f0992b6e | ||
|
|
1feff61977 | ||
|
|
5e0b904e88 |
@@ -86,9 +86,9 @@ if(CMAKE_CUDA_COMPILER)
|
||||
)
|
||||
endif()
|
||||
|
||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
|
||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
|
||||
CACHE STRING
|
||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
|
||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
|
||||
)
|
||||
|
||||
check_language(HIP)
|
||||
@@ -97,7 +97,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
|
||||
find_package(hip REQUIRED)
|
||||
if(NOT AMDGPU_TARGETS)
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
|
||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||
endif()
|
||||
|
||||
@@ -56,7 +56,7 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -285,6 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||
@@ -324,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
|
||||
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
|
||||
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
|
||||
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
|
||||
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
||||
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
|
||||
@@ -394,6 +396,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
||||
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
|
||||
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -433,7 +437,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
|
||||
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
||||
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
||||
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
|
||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
|
||||
46
api/types.go
46
api/types.go
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
@@ -81,7 +82,7 @@ type GenerateRequest struct {
|
||||
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -106,7 +107,7 @@ type ChatRequest struct {
|
||||
Tools `json:"tools,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -260,7 +261,7 @@ type EmbedRequest struct {
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
// EmbedResponse is the response from [Client.Embed].
|
||||
@@ -286,7 +287,7 @@ type EmbeddingRequest struct {
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse is the response from [Client.Embeddings].
|
||||
@@ -332,7 +333,7 @@ type ShowRequest struct {
|
||||
Template string `json:"template"`
|
||||
Verbose bool `json:"verbose"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Deprecated: set the model name with Model instead
|
||||
Name string `json:"name"`
|
||||
@@ -340,17 +341,18 @@ type ShowRequest struct {
|
||||
|
||||
// ShowResponse is the response returned from [Client.Show].
|
||||
type ShowResponse struct {
|
||||
License string `json:"license,omitempty"`
|
||||
Modelfile string `json:"modelfile,omitempty"`
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
License string `json:"license,omitempty"`
|
||||
Modelfile string `json:"modelfile,omitempty"`
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
@@ -503,7 +505,7 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
}
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
func (opts *Options) FromMap(m map[string]any) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||
|
||||
@@ -560,12 +562,12 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
}
|
||||
field.SetString(val)
|
||||
case reflect.Slice:
|
||||
// JSON unmarshals to []interface{}, not []string
|
||||
val, ok := val.([]interface{})
|
||||
// JSON unmarshals to []any, not []string
|
||||
val, ok := val.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type array", key)
|
||||
}
|
||||
// convert []interface{} to []string
|
||||
// convert []any to []string
|
||||
slice := make([]string, len(val))
|
||||
for i, item := range val {
|
||||
str, ok := item.(string)
|
||||
@@ -672,7 +674,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
}
|
||||
|
||||
// FormatParams converts specified parameter options to their correct types
|
||||
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||
func FormatParams(params map[string][]string) (map[string]any, error) {
|
||||
opts := Options{}
|
||||
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
||||
@@ -686,7 +688,7 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||
}
|
||||
}
|
||||
|
||||
out := make(map[string]interface{})
|
||||
out := make(map[string]any)
|
||||
// iterate params and set values based on json struct tags
|
||||
for key, vals := range params {
|
||||
if opt, ok := jsonOpts[key]; !ok {
|
||||
|
||||
@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var oMap map[string]interface{}
|
||||
var oMap map[string]any
|
||||
err := json.Unmarshal([]byte(test.req), &oMap)
|
||||
require.NoError(t, err)
|
||||
opts := DefaultOptions()
|
||||
|
||||
@@ -92,7 +92,7 @@ func BenchmarkColdStart(b *testing.B) {
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
@@ -155,7 +155,7 @@ func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||
&api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
|
||||
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
|
||||
},
|
||||
func(api.GenerateResponse) error { return nil },
|
||||
)
|
||||
|
||||
19
cmd/cmd.go
19
cmd/cmd.go
@@ -18,6 +18,7 @@ import (
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -267,7 +268,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
Options: map[string]any{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
@@ -339,6 +340,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
// that don't have the capabilities field in the model info
|
||||
if len(info.ProjectorInfo) != 0 {
|
||||
opts.MultiModal = true
|
||||
}
|
||||
@@ -669,6 +675,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
return
|
||||
})
|
||||
|
||||
if len(resp.Capabilities) > 0 {
|
||||
tableRender("Capabilities", func() (rows [][]string) {
|
||||
for _, capability := range resp.Capabilities {
|
||||
rows = append(rows, []string{"", capability.String()})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
if resp.ProjectorInfo != nil {
|
||||
tableRender("Projector", func() (rows [][]string) {
|
||||
arch := resp.ProjectorInfo["general.architecture"].(string)
|
||||
@@ -837,7 +852,7 @@ type runOptions struct {
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]interface{}
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@@ -260,6 +261,34 @@ Weigh anchor!
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("capabilities", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := " Model\n" +
|
||||
" architecture test \n" +
|
||||
" parameters 7B \n" +
|
||||
" quantization FP16 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
" vision \n" +
|
||||
" tools \n" +
|
||||
"\n"
|
||||
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteHandler(t *testing.T) {
|
||||
|
||||
@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
|
||||
|
||||
var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
|
||||
var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_sentencepiece_model_proto_goTypes = []interface{}{
|
||||
var file_sentencepiece_model_proto_goTypes = []any{
|
||||
(TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType
|
||||
(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
|
||||
(*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec
|
||||
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*TrainerSpec); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*NormalizerSpec); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*SelfTestData); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ModelProto); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*SelfTestData_Sample); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ModelProto_SentencePiece); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
|
||||
@@ -12,7 +12,7 @@ func IsNUMA() bool {
|
||||
// numa support in llama.cpp is linux only
|
||||
return false
|
||||
}
|
||||
ids := map[string]interface{}{}
|
||||
ids := map[string]any{}
|
||||
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
|
||||
for _, packageId := range packageIds {
|
||||
id, err := os.ReadFile(packageId)
|
||||
|
||||
@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
return linuxCPUDetails(file)
|
||||
}
|
||||
|
||||
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
||||
for id, s := range socketByID {
|
||||
s.CoreCount = len(coreBySocket[id])
|
||||
s.ThreadCount = 0
|
||||
for _, tc := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += tc
|
||||
}
|
||||
|
||||
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
||||
efficiencyCoreCount := 0
|
||||
for _, threads := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += threads
|
||||
if threads == 1 {
|
||||
efficiencyCoreCount++
|
||||
}
|
||||
|
||||
@@ -1217,7 +1217,7 @@ Show information about a model including details, modelfile, template, parameter
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/show -d '{
|
||||
"model": "llama3.2"
|
||||
"model": "llava"
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -1260,7 +1260,11 @@ curl http://localhost:11434/api/show -d '{
|
||||
"tokenizer.ggml.pre": "llama-bpe",
|
||||
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
|
||||
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
|
||||
}
|
||||
},
|
||||
"capabilities": [
|
||||
"completion",
|
||||
"vision"
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
```shell
|
||||
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||
```
|
||||
|
||||
To change this when using `ollama run`, use `/set parameter`:
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
||||
On **Linux** systems with systemd, the logs can be found with this command:
|
||||
|
||||
```shell
|
||||
journalctl -u ollama --no-pager
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
```
|
||||
|
||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func assertEqual(t *testing.T, a interface{}, b interface{}) {
|
||||
func assertEqual(t *testing.T, a any, b any) {
|
||||
if a != b {
|
||||
t.Errorf("Assert failed, expected %v, got %v", b, a)
|
||||
}
|
||||
|
||||
@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama":
|
||||
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
case "mllama":
|
||||
var visionTokens, tiles uint64 = 1601, 4
|
||||
|
||||
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
||||
kv = headsKV *
|
||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
||||
(2* // sizeof(float16)
|
||||
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
||||
context +
|
||||
4* // sizeof(float32)
|
||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
||||
visionTokens*
|
||||
tiles)
|
||||
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||
for i := range kv {
|
||||
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||
4 * // sizeof(float32)
|
||||
visionTokens *
|
||||
tiles
|
||||
}
|
||||
}
|
||||
|
||||
fullOffload = max(
|
||||
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
4*embeddingHeadsK*context*8+
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
const gemma3GlobalCacheCount = 6
|
||||
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
|
||||
for i := range kv {
|
||||
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||
// layers are the smaller local (sliding) layers.
|
||||
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "command-r":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
//go:build go1.24
|
||||
|
||||
package grammar
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkFromSchema(b *testing.B) {
|
||||
for tt := range testCases(b) {
|
||||
b.Run("", func(b *testing.B) {
|
||||
s := []byte(tt.schema)
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
_, err := FromSchema(nil, s)
|
||||
if err != nil {
|
||||
b.Fatalf("GrammarFromSchema: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"iter"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/grammar/jsonschema"
|
||||
)
|
||||
|
||||
const jsonTerms = `
|
||||
# Unicode
|
||||
#
|
||||
# Unicode characters can be specified directly in the grammar, for example
|
||||
# hiragana ::= [ぁ-ゟ], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
|
||||
# (\UXXXXXXXX).
|
||||
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}
|
||||
|
||||
# JSON grammar from RFC 7159
|
||||
null ::= "null"
|
||||
object ::= "{" (kv ("," kv)*)? "}"
|
||||
array ::= "[" (value ("," value)*)? "]"
|
||||
kv ::= string ":" value
|
||||
integer ::= "0" | [1-9] [0-9]*
|
||||
number ::= "-"? integer frac? exp?
|
||||
frac ::= "." [0-9]+
|
||||
exp ::= ("e" | "E") ("+" | "-") [0-9]+
|
||||
string ::= "\"" char* "\""
|
||||
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
|
||||
char ::= [^"\\] | escape
|
||||
space ::= (" " | "\t" | "\n" | "\r")*
|
||||
hex ::= [0-9] | [a-f] | [A-F]
|
||||
boolean ::= "true" | "false"
|
||||
value ::= object | array | string | number | boolean | "null"
|
||||
|
||||
# User-defined
|
||||
`
|
||||
|
||||
// FromSchema generates a grammar from a JSON schema.
|
||||
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
|
||||
var s *jsonschema.Schema
|
||||
if err := json.Unmarshal(jsonSchema, &s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var g builder
|
||||
|
||||
// "root" is the only rule that is guaranteed to exist, so we start
|
||||
// with its length for padding, and then adjust it as we go.
|
||||
g.pad = len("root")
|
||||
for id := range dependencies("root", s) {
|
||||
g.pad = max(g.pad, len(id))
|
||||
}
|
||||
|
||||
g.b.WriteString(jsonTerms)
|
||||
|
||||
ids := make(map[*jsonschema.Schema]string)
|
||||
for id, s := range dependencies("root", s) {
|
||||
ids[s] = id
|
||||
g.define(id)
|
||||
if err := fromSchema(&g, ids, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
g.define("root")
|
||||
if err := fromSchema(&g, ids, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
g.define("") // finalize the last rule
|
||||
return g.b.Bytes(), nil
|
||||
}
|
||||
|
||||
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
|
||||
switch typ := s.EffectiveType(); typ {
|
||||
case "array":
|
||||
if len(s.PrefixItems) == 0 && s.Items == nil {
|
||||
g.u("array")
|
||||
} else {
|
||||
g.q("[")
|
||||
for i, s := range s.PrefixItems {
|
||||
if i > 0 {
|
||||
g.q(",")
|
||||
}
|
||||
g.u(ids[s])
|
||||
}
|
||||
if s.Items != nil {
|
||||
g.u("(")
|
||||
if len(s.PrefixItems) > 0 {
|
||||
g.q(",")
|
||||
}
|
||||
g.u(ids[s.Items])
|
||||
g.u(")*")
|
||||
}
|
||||
g.q("]")
|
||||
}
|
||||
case "object":
|
||||
if len(s.Properties) == 0 {
|
||||
g.u("object")
|
||||
} else {
|
||||
g.q("{")
|
||||
for i, p := range s.Properties {
|
||||
name := ids[p]
|
||||
if i > 0 {
|
||||
g.q(",")
|
||||
}
|
||||
g.q(p.Name)
|
||||
g.q(":")
|
||||
g.u(name)
|
||||
}
|
||||
g.q("}")
|
||||
}
|
||||
case "number":
|
||||
buildConstrainedNumber(g, s)
|
||||
case "string":
|
||||
if len(s.Enum) == 0 {
|
||||
g.u("string")
|
||||
} else {
|
||||
g.u("(")
|
||||
for i, e := range s.Enum {
|
||||
if i > 0 {
|
||||
g.q("|")
|
||||
}
|
||||
g.q(string(e))
|
||||
}
|
||||
g.u(")")
|
||||
}
|
||||
case "boolean", "value", "null", "integer":
|
||||
g.u(typ)
|
||||
default:
|
||||
return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dependencies returns a sequence of all child dependencies of the schema in
|
||||
// post-order.
|
||||
//
|
||||
// The first value is the id/pointer to the dependency, and the second value
|
||||
// is the schema.
|
||||
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
|
||||
return func(yield func(string, *jsonschema.Schema) bool) {
|
||||
for i, p := range s.Properties {
|
||||
id := fmt.Sprintf("%s_%d", id, i)
|
||||
for did, d := range dependencies(id, p) {
|
||||
if !yield(did, d) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !yield(id, p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
for i, p := range s.PrefixItems {
|
||||
id := fmt.Sprintf("tuple_%d", i)
|
||||
for did, d := range dependencies(id, p) {
|
||||
id := fmt.Sprintf("%s_%s", id, did)
|
||||
if !yield(id, d) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !yield(id, p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if s.Items != nil {
|
||||
id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
|
||||
for did, d := range dependencies(id, s.Items) {
|
||||
if !yield(did, d) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !yield(id, s.Items) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type builder struct {
|
||||
b bytes.Buffer
|
||||
pad int
|
||||
rules int
|
||||
items int
|
||||
}
|
||||
|
||||
// define terminates the current rule, if any, and then either starts a new
|
||||
// rule or does nothing else if the name is empty.
|
||||
func (b *builder) define(name string) {
|
||||
if b.rules > 0 {
|
||||
b.b.WriteString(";\n")
|
||||
}
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(&b.b, "% -*s", b.pad, name)
|
||||
b.b.WriteString(" ::=")
|
||||
b.rules++
|
||||
b.items = 0
|
||||
}
|
||||
|
||||
// quote appends a terminal to the current rule.
|
||||
func (b *builder) q(s string) {
|
||||
if b.items > 0 {
|
||||
b.b.WriteString(" ")
|
||||
}
|
||||
b.b.WriteString(" ")
|
||||
b.b.WriteString(strconv.Quote(s))
|
||||
}
|
||||
|
||||
// u appends a non-terminal to the current rule.
|
||||
func (b *builder) u(s string) {
|
||||
if b.items > 0 {
|
||||
b.b.WriteString(" ")
|
||||
}
|
||||
b.b.WriteString(" ")
|
||||
b.b.WriteString(s)
|
||||
}
|
||||
|
||||
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
|
||||
if s.Minimum == 0 && s.Maximum == 0 {
|
||||
b.u("TODO")
|
||||
} else {
|
||||
b.u("number")
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"cmp"
|
||||
"iter"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/ollama/ollama/grammar/internal/diff"
|
||||
)
|
||||
|
||||
func TestFromSchema(t *testing.T) {
|
||||
for tt := range testCases(t) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g, err := FromSchema(nil, []byte(tt.schema))
|
||||
if err != nil {
|
||||
t.Fatalf("FromSchema: %v", err)
|
||||
}
|
||||
got := string(g)
|
||||
got = strings.TrimPrefix(got, jsonTerms)
|
||||
if got != tt.want {
|
||||
t.Logf("schema:\n%s", tt.schema)
|
||||
t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
schema string
|
||||
want string
|
||||
}
|
||||
|
||||
//go:embed testdata/schemas.txt
|
||||
var tests string
|
||||
|
||||
func testCases(t testing.TB) iter.Seq[testCase] {
|
||||
t.Helper()
|
||||
return func(yield func(testCase) bool) {
|
||||
t.Helper()
|
||||
sc := bufio.NewScanner(strings.NewReader(tests))
|
||||
name := ""
|
||||
for sc.Scan() {
|
||||
line := strings.TrimSpace(sc.Text())
|
||||
if line == "" {
|
||||
name = ""
|
||||
continue
|
||||
}
|
||||
if line[0] == '#' {
|
||||
name = cmp.Or(name, strings.TrimSpace(line[1:]))
|
||||
continue
|
||||
}
|
||||
s := sc.Text()
|
||||
g := ""
|
||||
for sc.Scan() {
|
||||
line = strings.TrimSpace(sc.Text())
|
||||
if line == "" || line[0] == '#' {
|
||||
break
|
||||
}
|
||||
g += sc.Text() + "\n"
|
||||
}
|
||||
if !yield(testCase{name, s, g}) {
|
||||
return
|
||||
}
|
||||
name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
|
||||
}
|
||||
if err := sc.Err(); err != nil {
|
||||
t.Fatalf("error reading tests: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,261 +0,0 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package diff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A pair is a pair of values tracked for both the x and y side of a diff.
|
||||
// It is typically a pair of line indexes.
|
||||
type pair struct{ x, y int }
|
||||
|
||||
// Diff returns an anchored diff of the two texts old and new
|
||||
// in the “unified diff” format. If old and new are identical,
|
||||
// Diff returns a nil slice (no output).
|
||||
//
|
||||
// Unix diff implementations typically look for a diff with
|
||||
// the smallest number of lines inserted and removed,
|
||||
// which can in the worst case take time quadratic in the
|
||||
// number of lines in the texts. As a result, many implementations
|
||||
// either can be made to run for a long time or cut off the search
|
||||
// after a predetermined amount of work.
|
||||
//
|
||||
// In contrast, this implementation looks for a diff with the
|
||||
// smallest number of “unique” lines inserted and removed,
|
||||
// where unique means a line that appears just once in both old and new.
|
||||
// We call this an “anchored diff” because the unique lines anchor
|
||||
// the chosen matching regions. An anchored diff is usually clearer
|
||||
// than a standard diff, because the algorithm does not try to
|
||||
// reuse unrelated blank lines or closing braces.
|
||||
// The algorithm also guarantees to run in O(n log n) time
|
||||
// instead of the standard O(n²) time.
|
||||
//
|
||||
// Some systems call this approach a “patience diff,” named for
|
||||
// the “patience sorting” algorithm, itself named for a solitaire card game.
|
||||
// We avoid that name for two reasons. First, the name has been used
|
||||
// for a few different variants of the algorithm, so it is imprecise.
|
||||
// Second, the name is frequently interpreted as meaning that you have
|
||||
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
|
||||
// when in fact the algorithm is faster than the standard one.
|
||||
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
|
||||
if bytes.Equal(old, new) {
|
||||
return nil
|
||||
}
|
||||
x := lines(old)
|
||||
y := lines(new)
|
||||
|
||||
// Print diff header.
|
||||
var out bytes.Buffer
|
||||
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
|
||||
fmt.Fprintf(&out, "--- %s\n", oldName)
|
||||
fmt.Fprintf(&out, "+++ %s\n", newName)
|
||||
|
||||
// Loop over matches to consider,
|
||||
// expanding each match to include surrounding lines,
|
||||
// and then printing diff chunks.
|
||||
// To avoid setup/teardown cases outside the loop,
|
||||
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
|
||||
// in the sequence of matches.
|
||||
var (
|
||||
done pair // printed up to x[:done.x] and y[:done.y]
|
||||
chunk pair // start lines of current chunk
|
||||
count pair // number of lines from each side in current chunk
|
||||
ctext []string // lines for current chunk
|
||||
)
|
||||
for _, m := range tgs(x, y) {
|
||||
if m.x < done.x {
|
||||
// Already handled scanning forward from earlier match.
|
||||
continue
|
||||
}
|
||||
|
||||
// Expand matching lines as far as possible,
|
||||
// establishing that x[start.x:end.x] == y[start.y:end.y].
|
||||
// Note that on the first (or last) iteration we may (or definitely do)
|
||||
// have an empty match: start.x==end.x and start.y==end.y.
|
||||
start := m
|
||||
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
|
||||
start.x--
|
||||
start.y--
|
||||
}
|
||||
end := m
|
||||
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
|
||||
end.x++
|
||||
end.y++
|
||||
}
|
||||
|
||||
// Emit the mismatched lines before start into this chunk.
|
||||
// (No effect on first sentinel iteration, when start = {0,0}.)
|
||||
for _, s := range x[done.x:start.x] {
|
||||
ctext = append(ctext, "-"+s)
|
||||
count.x++
|
||||
}
|
||||
for _, s := range y[done.y:start.y] {
|
||||
ctext = append(ctext, "+"+s)
|
||||
count.y++
|
||||
}
|
||||
|
||||
// If we're not at EOF and have too few common lines,
|
||||
// the chunk includes all the common lines and continues.
|
||||
const C = 3 // number of context lines
|
||||
if (end.x < len(x) || end.y < len(y)) &&
|
||||
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
|
||||
for _, s := range x[start.x:end.x] {
|
||||
ctext = append(ctext, " "+s)
|
||||
count.x++
|
||||
count.y++
|
||||
}
|
||||
done = end
|
||||
continue
|
||||
}
|
||||
|
||||
// End chunk with common lines for context.
|
||||
if len(ctext) > 0 {
|
||||
n := end.x - start.x
|
||||
if n > C {
|
||||
n = C
|
||||
}
|
||||
for _, s := range x[start.x : start.x+n] {
|
||||
ctext = append(ctext, " "+s)
|
||||
count.x++
|
||||
count.y++
|
||||
}
|
||||
done = pair{start.x + n, start.y + n}
|
||||
|
||||
// Format and emit chunk.
|
||||
// Convert line numbers to 1-indexed.
|
||||
// Special case: empty file shows up as 0,0 not 1,0.
|
||||
if count.x > 0 {
|
||||
chunk.x++
|
||||
}
|
||||
if count.y > 0 {
|
||||
chunk.y++
|
||||
}
|
||||
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
|
||||
for _, s := range ctext {
|
||||
out.WriteString(s)
|
||||
}
|
||||
count.x = 0
|
||||
count.y = 0
|
||||
ctext = ctext[:0]
|
||||
}
|
||||
|
||||
// If we reached EOF, we're done.
|
||||
if end.x >= len(x) && end.y >= len(y) {
|
||||
break
|
||||
}
|
||||
|
||||
// Otherwise start a new chunk.
|
||||
chunk = pair{end.x - C, end.y - C}
|
||||
for _, s := range x[chunk.x:end.x] {
|
||||
ctext = append(ctext, " "+s)
|
||||
count.x++
|
||||
count.y++
|
||||
}
|
||||
done = end
|
||||
}
|
||||
|
||||
return out.Bytes()
|
||||
}
|
||||
|
||||
// lines returns the lines in the file x, including newlines.
|
||||
// If the file does not end in a newline, one is supplied
|
||||
// along with a warning about the missing newline.
|
||||
func lines(x []byte) []string {
|
||||
l := strings.SplitAfter(string(x), "\n")
|
||||
if l[len(l)-1] == "" {
|
||||
l = l[:len(l)-1]
|
||||
} else {
|
||||
// Treat last line as having a message about the missing newline attached,
|
||||
// using the same text as BSD/GNU diff (including the leading backslash).
|
||||
l[len(l)-1] += "\n\\ No newline at end of file\n"
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// tgs returns the pairs of indexes of the longest common subsequence
|
||||
// of unique lines in x and y, where a unique line is one that appears
|
||||
// once in x and once in y.
|
||||
//
|
||||
// The longest common subsequence algorithm is as described in
|
||||
// Thomas G. Szymanski, “A Special Case of the Maximal Common
|
||||
// Subsequence Problem,” Princeton TR #170 (January 1975),
|
||||
// available at https://research.swtch.com/tgs170.pdf.
|
||||
func tgs(x, y []string) []pair {
|
||||
// Count the number of times each string appears in a and b.
|
||||
// We only care about 0, 1, many, counted as 0, -1, -2
|
||||
// for the x side and 0, -4, -8 for the y side.
|
||||
// Using negative numbers now lets us distinguish positive line numbers later.
|
||||
m := make(map[string]int)
|
||||
for _, s := range x {
|
||||
if c := m[s]; c > -2 {
|
||||
m[s] = c - 1
|
||||
}
|
||||
}
|
||||
for _, s := range y {
|
||||
if c := m[s]; c > -8 {
|
||||
m[s] = c - 4
|
||||
}
|
||||
}
|
||||
|
||||
// Now unique strings can be identified by m[s] = -1+-4.
|
||||
//
|
||||
// Gather the indexes of those strings in x and y, building:
|
||||
// xi[i] = increasing indexes of unique strings in x.
|
||||
// yi[i] = increasing indexes of unique strings in y.
|
||||
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
|
||||
var xi, yi, inv []int
|
||||
for i, s := range y {
|
||||
if m[s] == -1+-4 {
|
||||
m[s] = len(yi)
|
||||
yi = append(yi, i)
|
||||
}
|
||||
}
|
||||
for i, s := range x {
|
||||
if j, ok := m[s]; ok && j >= 0 {
|
||||
xi = append(xi, i)
|
||||
inv = append(inv, j)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Algorithm A from Szymanski's paper.
|
||||
// In those terms, A = J = inv and B = [0, n).
|
||||
// We add sentinel pairs {0,0}, and {len(x),len(y)}
|
||||
// to the returned sequence, to help the processing loop.
|
||||
J := inv
|
||||
n := len(xi)
|
||||
T := make([]int, n)
|
||||
L := make([]int, n)
|
||||
for i := range T {
|
||||
T[i] = n + 1
|
||||
}
|
||||
for i := range n {
|
||||
k := sort.Search(n, func(k int) bool {
|
||||
return T[k] >= J[i]
|
||||
})
|
||||
T[k] = J[i]
|
||||
L[i] = k + 1
|
||||
}
|
||||
k := 0
|
||||
for _, v := range L {
|
||||
if k < v {
|
||||
k = v
|
||||
}
|
||||
}
|
||||
seq := make([]pair, 2+k)
|
||||
seq[1+k] = pair{len(x), len(y)} // sentinel at end
|
||||
lastj := n
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
if L[i] == k && J[i] < lastj {
|
||||
seq[k] = pair{xi[i], yi[J[i]]}
|
||||
k--
|
||||
}
|
||||
}
|
||||
seq[0] = pair{0, 0} // sentinel at start
|
||||
return seq
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package diff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/tools/txtar"
|
||||
)
|
||||
|
||||
func clean(text []byte) []byte {
|
||||
text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
|
||||
text = bytes.TrimSuffix(text, []byte("^D\n"))
|
||||
return text
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
files, _ := filepath.Glob("testdata/*.txt")
|
||||
if len(files) == 0 {
|
||||
t.Fatalf("no testdata")
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
t.Run(filepath.Base(file), func(t *testing.T) {
|
||||
a, err := txtar.ParseFile(file)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(a.Files) != 3 || a.Files[2].Name != "diff" {
|
||||
t.Fatalf("%s: want three files, third named \"diff\"", file)
|
||||
}
|
||||
diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
|
||||
want := clean(a.Files[2].Data)
|
||||
if !bytes.Equal(diffs, want) {
|
||||
t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
|
||||
diffs, want, Diff("have", diffs, "want", want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
13
grammar/internal/diff/testdata/allnew.txt
vendored
13
grammar/internal/diff/testdata/allnew.txt
vendored
@@ -1,13 +0,0 @@
|
||||
-- old --
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -0,0 +1,3 @@
|
||||
+a
|
||||
+b
|
||||
+c
|
||||
13
grammar/internal/diff/testdata/allold.txt
vendored
13
grammar/internal/diff/testdata/allold.txt
vendored
@@ -1,13 +0,0 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- new --
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,3 +0,0 @@
|
||||
-a
|
||||
-b
|
||||
-c
|
||||
35
grammar/internal/diff/testdata/basic.txt
vendored
35
grammar/internal/diff/testdata/basic.txt
vendored
@@ -1,35 +0,0 @@
|
||||
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
|
||||
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
-- new --
|
||||
w
|
||||
a
|
||||
b
|
||||
x
|
||||
y
|
||||
z
|
||||
e
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,7 +1,7 @@
|
||||
+w
|
||||
a
|
||||
b
|
||||
-c
|
||||
-d
|
||||
+x
|
||||
+y
|
||||
+z
|
||||
e
|
||||
-f
|
||||
-g
|
||||
40
grammar/internal/diff/testdata/dups.txt
vendored
40
grammar/internal/diff/testdata/dups.txt
vendored
@@ -1,40 +0,0 @@
|
||||
-- old --
|
||||
a
|
||||
|
||||
b
|
||||
|
||||
c
|
||||
|
||||
d
|
||||
|
||||
e
|
||||
|
||||
f
|
||||
-- new --
|
||||
a
|
||||
|
||||
B
|
||||
|
||||
C
|
||||
|
||||
d
|
||||
|
||||
e
|
||||
|
||||
f
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,8 +1,8 @@
|
||||
a
|
||||
$
|
||||
-b
|
||||
-
|
||||
-c
|
||||
+B
|
||||
+
|
||||
+C
|
||||
$
|
||||
d
|
||||
$
|
||||
38
grammar/internal/diff/testdata/end.txt
vendored
38
grammar/internal/diff/testdata/end.txt
vendored
@@ -1,38 +0,0 @@
|
||||
-- old --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
eight
|
||||
nine
|
||||
ten
|
||||
eleven
|
||||
-- new --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -5,7 +5,6 @@
|
||||
5
|
||||
6
|
||||
7
|
||||
-eight
|
||||
-nine
|
||||
-ten
|
||||
-eleven
|
||||
+8
|
||||
+9
|
||||
+10
|
||||
9
grammar/internal/diff/testdata/eof.txt
vendored
9
grammar/internal/diff/testdata/eof.txt
vendored
@@ -1,9 +0,0 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- diff --
|
||||
18
grammar/internal/diff/testdata/eof1.txt
vendored
18
grammar/internal/diff/testdata/eof1.txt
vendored
@@ -1,18 +0,0 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,3 +1,3 @@
|
||||
a
|
||||
b
|
||||
-c
|
||||
+c
|
||||
\ No newline at end of file
|
||||
18
grammar/internal/diff/testdata/eof2.txt
vendored
18
grammar/internal/diff/testdata/eof2.txt
vendored
@@ -1,18 +0,0 @@
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c^D
|
||||
-- new --
|
||||
a
|
||||
b
|
||||
c
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,3 +1,3 @@
|
||||
a
|
||||
b
|
||||
-c
|
||||
\ No newline at end of file
|
||||
+c
|
||||
62
grammar/internal/diff/testdata/long.txt
vendored
62
grammar/internal/diff/testdata/long.txt
vendored
@@ -1,62 +0,0 @@
|
||||
-- old --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
11
|
||||
12
|
||||
13
|
||||
14
|
||||
14½
|
||||
15
|
||||
16
|
||||
17
|
||||
18
|
||||
19
|
||||
20
|
||||
-- new --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
8
|
||||
9
|
||||
10
|
||||
11
|
||||
12
|
||||
13
|
||||
14
|
||||
17
|
||||
18
|
||||
19
|
||||
20
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -4,7 +4,6 @@
|
||||
4
|
||||
5
|
||||
6
|
||||
-7
|
||||
8
|
||||
9
|
||||
10
|
||||
@@ -12,9 +11,6 @@
|
||||
12
|
||||
13
|
||||
14
|
||||
-14½
|
||||
-15
|
||||
-16
|
||||
17
|
||||
18
|
||||
19
|
||||
5
grammar/internal/diff/testdata/same.txt
vendored
5
grammar/internal/diff/testdata/same.txt
vendored
@@ -1,5 +0,0 @@
|
||||
-- old --
|
||||
hello world
|
||||
-- new --
|
||||
hello world
|
||||
-- diff --
|
||||
34
grammar/internal/diff/testdata/start.txt
vendored
34
grammar/internal/diff/testdata/start.txt
vendored
@@ -1,34 +0,0 @@
|
||||
-- old --
|
||||
e
|
||||
pi
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
-- new --
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,5 +1,6 @@
|
||||
-e
|
||||
-pi
|
||||
+1
|
||||
+2
|
||||
+3
|
||||
4
|
||||
5
|
||||
6
|
||||
40
grammar/internal/diff/testdata/triv.txt
vendored
40
grammar/internal/diff/testdata/triv.txt
vendored
@@ -1,40 +0,0 @@
|
||||
Another example from Hunt and McIlroy,
|
||||
“An Algorithm for Differential File Comparison.”
|
||||
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||
|
||||
Anchored diff gives up on finding anything,
|
||||
since there are no unique lines.
|
||||
|
||||
-- old --
|
||||
a
|
||||
b
|
||||
c
|
||||
a
|
||||
b
|
||||
b
|
||||
a
|
||||
-- new --
|
||||
c
|
||||
a
|
||||
b
|
||||
a
|
||||
b
|
||||
c
|
||||
-- diff --
|
||||
diff old new
|
||||
--- old
|
||||
+++ new
|
||||
@@ -1,7 +1,6 @@
|
||||
-a
|
||||
-b
|
||||
-c
|
||||
-a
|
||||
-b
|
||||
-b
|
||||
-a
|
||||
+c
|
||||
+a
|
||||
+b
|
||||
+a
|
||||
+b
|
||||
+c
|
||||
@@ -1,171 +0,0 @@
|
||||
package jsonschema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Schema holds a JSON schema.
|
||||
type Schema struct {
|
||||
// Name is the name of the property. For the parent/root property, this
|
||||
// is "root". For child properties, this is the name of the property.
|
||||
Name string `json:"-"`
|
||||
|
||||
// Type is the type of the property.
|
||||
//
|
||||
// TODO: Union types (e.g. make this a []string).
|
||||
Type string
|
||||
|
||||
// PrefixItems is a list of schemas for each item in a tuple. By
|
||||
// default, the tuple is "closed." unless Items is set to true or a
|
||||
// valid Schema.
|
||||
PrefixItems []*Schema
|
||||
|
||||
// Items is the schema for each item in a list.
|
||||
//
|
||||
// If it is missing, or its JSON value is "null" or "false", it is nil.
|
||||
// If the JSON value is "true", it is set to the empty Schema. If the
|
||||
// JSON value is an object, it will be decoded as a Schema.
|
||||
Items *Schema
|
||||
|
||||
// MinItems specifies the minimum number of items allowed in a list.
|
||||
MinItems int
|
||||
|
||||
// MaxItems specifies the maximum number of items allowed in a list.
|
||||
MaxItems int
|
||||
|
||||
// Properties is the schema for each property of an object.
|
||||
Properties []*Schema
|
||||
|
||||
// Format is the format of the property. This is used to validate the
|
||||
// property against a specific format.
|
||||
//
|
||||
// It is the callers responsibility to validate the property against
|
||||
// the format.
|
||||
Format string
|
||||
|
||||
// Minimum specifies the minimum value for numeric properties.
|
||||
Minimum float64
|
||||
|
||||
// Maximum specifies the maximum value for numeric properties.
|
||||
Maximum float64
|
||||
|
||||
// Enum is a list of valid values for the property.
|
||||
Enum []json.RawMessage
|
||||
}
|
||||
|
||||
func (s *Schema) UnmarshalJSON(data []byte) error {
|
||||
type S Schema
|
||||
w := struct {
|
||||
Properties props
|
||||
Items items
|
||||
*S
|
||||
}{
|
||||
S: (*S)(s),
|
||||
}
|
||||
if err := json.Unmarshal(data, &w); err != nil {
|
||||
return err
|
||||
}
|
||||
if w.Items.set {
|
||||
s.Items = &w.Items.Schema
|
||||
}
|
||||
s.Properties = w.Properties
|
||||
return nil
|
||||
}
|
||||
|
||||
type items struct {
|
||||
Schema
|
||||
set bool
|
||||
}
|
||||
|
||||
func (s *items) UnmarshalJSON(data []byte) error {
|
||||
switch b := data[0]; b {
|
||||
case 't':
|
||||
*s = items{set: true}
|
||||
case '{':
|
||||
type I items
|
||||
if err := json.Unmarshal(data, (*I)(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.set = true
|
||||
case 'n', 'f':
|
||||
default:
|
||||
return errors.New("invalid Items")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EffectiveType returns the effective type of the schema. If the Type field is
|
||||
// not empty, it is returned; otherwise:
|
||||
//
|
||||
// - If the schema has both Properties and Items, it returns an empty string.
|
||||
// - If the schema has Properties, it returns "object".
|
||||
// - If the schema has Items, it returns "array".
|
||||
// - If the schema has neither Properties nor Items, it returns "value".
|
||||
//
|
||||
// The returned string is never empty.
|
||||
func (d *Schema) EffectiveType() string {
|
||||
if d.Type == "" {
|
||||
if len(d.Properties) > 0 {
|
||||
return "object"
|
||||
}
|
||||
if len(d.PrefixItems) > 0 || d.Items != nil {
|
||||
return "array"
|
||||
}
|
||||
return "value"
|
||||
}
|
||||
return d.Type
|
||||
}
|
||||
|
||||
// props is an ordered list of properties. The order of the properties
|
||||
// is the order in which they were defined in the schema.
|
||||
type props []*Schema
|
||||
|
||||
var _ json.Unmarshaler = (*props)(nil)
|
||||
|
||||
func (v *props) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
if data[0] != '{' {
|
||||
return errors.New("expected object")
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewReader(data))
|
||||
|
||||
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
|
||||
// llama.cpp, ignore unknown fields, which could be lead to unexpected
|
||||
// behavior for clients of this package, since they may not be aware
|
||||
// that "additionalFields", "itemsPrefix", etc, are being ignored.
|
||||
//
|
||||
// For now, just do what llama.cpp does.
|
||||
|
||||
t, err := d.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t != json.Delim('{') {
|
||||
return errors.New("expected object")
|
||||
}
|
||||
for d.More() {
|
||||
// Use the first token (map key) as the property name, then
|
||||
// decode the rest of the object fields into a Schema and
|
||||
// append.
|
||||
t, err := d.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t == json.Delim('}') {
|
||||
return nil
|
||||
}
|
||||
s := &Schema{
|
||||
Name: t.(string),
|
||||
}
|
||||
if err := d.Decode(s); err != nil {
|
||||
return err
|
||||
}
|
||||
*v = append(*v, s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
package jsonschema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
const testSchemaBasic = `
|
||||
{
|
||||
"properties": {
|
||||
"tupleClosedEmpty": { "prefixItems": [] },
|
||||
"tupleClosedMissing": { "prefixItems": [{}] },
|
||||
"tupleClosedNull": { "prefixItems": [{}], "items": null },
|
||||
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
|
||||
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
|
||||
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
|
||||
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
|
||||
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
|
||||
|
||||
"array": { "items": {"type": "number"} },
|
||||
|
||||
"null": { "type": "null" },
|
||||
"string": { "type": "string" },
|
||||
"boolean": { "type": "boolean" }
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
func TestSchemaUnmarshal(t *testing.T) {
|
||||
var got *Schema
|
||||
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
|
||||
t.Fatalf("Unmarshal: %v", err)
|
||||
}
|
||||
want := &Schema{
|
||||
Properties: []*Schema{
|
||||
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
|
||||
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
|
||||
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
|
||||
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
|
||||
|
||||
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
|
||||
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
|
||||
|
||||
{Name: "array", Items: &Schema{Type: "number"}},
|
||||
|
||||
{Name: "null", Type: "null"},
|
||||
{Name: "string", Type: "string"},
|
||||
{Name: "boolean", Type: "boolean"},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("(-want, +got)\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveType(t *testing.T) {
|
||||
const schema = `
|
||||
{"properties": {
|
||||
"o": {"type": "object"},
|
||||
"a": {"type": "array"},
|
||||
"n": {"type": "number"},
|
||||
"s": {"type": "string"},
|
||||
"z": {"type": "null"},
|
||||
"b": {"type": "boolean"},
|
||||
|
||||
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
|
||||
"t1": {"items": {"type": "number"}, "maxItems": 3},
|
||||
|
||||
"v": {"maxItems": 3}
|
||||
}}
|
||||
`
|
||||
|
||||
var s *Schema
|
||||
if err := json.Unmarshal([]byte(schema), &s); err != nil {
|
||||
t.Fatalf("json.Unmarshal: %v", err)
|
||||
}
|
||||
|
||||
var got []string
|
||||
for _, p := range s.Properties {
|
||||
got = append(got, p.EffectiveType())
|
||||
}
|
||||
|
||||
want := strings.Fields(`
|
||||
object
|
||||
array
|
||||
number
|
||||
string
|
||||
null
|
||||
boolean
|
||||
array
|
||||
array
|
||||
value
|
||||
`)
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
|
||||
}
|
||||
}
|
||||
76
grammar/testdata/schemas.txt
vendored
76
grammar/testdata/schemas.txt
vendored
@@ -1,76 +0,0 @@
|
||||
# This file holds tests for JSON schema to EBNF grammar conversions.
|
||||
#
|
||||
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
|
||||
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
|
||||
# the JSON schema and the expected EBNF grammar. If no comment is present, the test
|
||||
# name the tests number in the file (e.g. "#0", "#1", etc.)
|
||||
#
|
||||
# Blank lines signify the end or start of a new test. Comments can be added
|
||||
# anywhere in the file, but they must be preceded by a '#' character and start at
|
||||
# the beginning of the line.
|
||||
|
||||
# default
|
||||
{}
|
||||
root ::= value;
|
||||
|
||||
{"properties": {}}
|
||||
root ::= value;
|
||||
|
||||
# array
|
||||
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
|
||||
root_0_tuple_0 ::= string;
|
||||
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||
root ::= "{" "a" ":" root_0 "}";
|
||||
|
||||
# array with nested array
|
||||
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
|
||||
root_tuple_0_tuple_0 ::= string;
|
||||
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
|
||||
root ::= "[" ( root_tuple_0 )* "]";
|
||||
|
||||
# object
|
||||
{"properties": {"e": {}}}
|
||||
root_0 ::= value;
|
||||
root ::= "{" "e" ":" root_0 "}";
|
||||
|
||||
# object with nested object
|
||||
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
|
||||
root_0_0 ::= value;
|
||||
root_0 ::= "{" "e" ":" root_0_0 "}";
|
||||
root ::= "{" "o" ":" root_0 "}";
|
||||
|
||||
# boolean
|
||||
{"type": "boolean"}
|
||||
root ::= boolean;
|
||||
|
||||
# number
|
||||
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
|
||||
root_0 ::= number;
|
||||
root ::= "{" "n" ":" root_0 "}";
|
||||
|
||||
# string
|
||||
{"type": "string"}
|
||||
root ::= string;
|
||||
|
||||
# string with enum
|
||||
{"type": "string", "enum": ["a", "b", "c"]}
|
||||
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
|
||||
|
||||
# spaces in key
|
||||
{"properties": {"a b": {}}}
|
||||
root_0 ::= value;
|
||||
root ::= "{" "a b" ":" root_0 "}";
|
||||
|
||||
# issue7978
|
||||
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
|
||||
root_0_tuple_0_0 ::= string;
|
||||
root_0_tuple_0_1 ::= string;
|
||||
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
|
||||
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||
root_1 ::= string;
|
||||
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
|
||||
|
||||
# !! # special characters in key
|
||||
# !! {"properties": {"a!b": {}}}
|
||||
# !! !invalid character '!' in key
|
||||
# !!
|
||||
@@ -22,7 +22,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
@@ -39,7 +39,7 @@ func TestUnicode(t *testing.T) {
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
// Workaround deepseek context shifting bug
|
||||
@@ -61,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
Model: "gemma2:2b",
|
||||
Prompt: "Output some smily face emoji",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
@@ -96,7 +96,7 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
Model: "llama2",
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": 128,
|
||||
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
Model: "llama2",
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": 128,
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestIntegrationLlava(t *testing.T) {
|
||||
Model: "llava:7b",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -47,7 +47,7 @@ func TestIntegrationMllama(t *testing.T) {
|
||||
Model: "x/llama3.2-vision",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -75,7 +75,7 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
||||
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -20,7 +20,7 @@ var (
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -28,7 +28,7 @@ var (
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestMaxQueue(t *testing.T) {
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -291,7 +291,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -300,7 +300,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -309,7 +309,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -318,7 +318,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -327,7 +327,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -62,6 +62,11 @@ type Cache interface {
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// CanResume returns true if the cache can continue with the next token at
|
||||
// the given position and sequence. Assumes that the caller has already
|
||||
// verified the contents of the cache.
|
||||
CanResume(seq int, pos int32) bool
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
|
||||
@@ -119,10 +119,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
}
|
||||
|
||||
var cacheSize int
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||
cacheSize = maxSequences * capacity
|
||||
} else {
|
||||
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
||||
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||
}
|
||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
@@ -581,6 +581,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
c.cellRanges[dstSeq] = seqRange
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
return true
|
||||
}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// for sliding window, check that the window of the new sequence is contained in
|
||||
// the window of what we are storing
|
||||
var last int32 = -1
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
|
||||
if last == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.windowSize)
|
||||
posWindowStart := max(0, pos-c.windowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
}
|
||||
|
||||
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
if c.shiftFn == nil {
|
||||
return ErrNotSupported
|
||||
@@ -635,6 +664,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||
// cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||
// should return an error, which will trigger the runner to evaluate the full history and
|
||||
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||
// results in use after free, so we don't do it for now.
|
||||
|
||||
var offset int32
|
||||
if endIndex != math.MaxInt32 {
|
||||
offset = beginIndex - endIndex
|
||||
@@ -649,8 +684,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
} else {
|
||||
if c.cells[i].pos >= endIndex {
|
||||
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||
// TODO(jessegross): Need to be careful about data shared between sequences
|
||||
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
|
||||
return errors.New("shifting cells shared by multiple sequences not supported")
|
||||
}
|
||||
|
||||
c.cells[i].pos += offset
|
||||
|
||||
@@ -300,6 +300,77 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanResume(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
windowSize := int32(4)
|
||||
cache := NewSWACache(windowSize, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
if !cache.CanResume(0, 0) {
|
||||
t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 1) {
|
||||
t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 2) {
|
||||
t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 3) {
|
||||
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
}
|
||||
|
||||
// shift window by adding position 4
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
if cache.CanResume(0, 0) {
|
||||
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 1) {
|
||||
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 2) {
|
||||
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 3) {
|
||||
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 4) {
|
||||
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 5) {
|
||||
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||
}
|
||||
}
|
||||
|
||||
type testBackend struct{}
|
||||
|
||||
func (b *testBackend) Config() ml.Config {
|
||||
@@ -362,7 +433,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
func (c *testContext) Input() ml.Context { return c }
|
||||
func (c *testContext) Output() ml.Context { return c }
|
||||
func (c *testContext) Layer(int) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
@@ -134,6 +134,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("encoder cache does not support multiple sequences")
|
||||
}
|
||||
|
||||
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
c.encoderCached = false
|
||||
|
||||
@@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||
for _, cache := range c.caches {
|
||||
if !cache.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||
for _, cache := range c.caches {
|
||||
|
||||
@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
|
||||
C.llama_kv_cache_defrag(c.c)
|
||||
}
|
||||
|
||||
func (c *Context) KvCacheCanShift() bool {
|
||||
return bool(C.llama_kv_cache_can_shift(c.c))
|
||||
}
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
||||
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||
|
||||
103
llama/patches/0022-add-rdna4-support.patch
Normal file
103
llama/patches/0022-add-rdna4-support.patch
Normal file
@@ -0,0 +1,103 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Saman <saman.khatir@amd.com>
|
||||
Date: Wed, 19 Mar 2025 14:02:26 -0700
|
||||
Subject: [PATCH] add rdna4 support
|
||||
|
||||
---
|
||||
ggml/src/ggml-cuda/common.cuh | 6 ++++--
|
||||
ggml/src/ggml-cuda/mmq.cu | 2 +-
|
||||
ggml/src/ggml-cuda/mmq.cuh | 4 ++--
|
||||
ggml/src/ggml-cuda/mmvq.cu | 4 ++--
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
|
||||
5 files changed, 13 insertions(+), 7 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||
index adf0d3ec..b24593fc 100644
|
||||
--- a/ggml/src/ggml-cuda/common.cuh
|
||||
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||
@@ -61,11 +61,13 @@
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||
|
||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
-#elif defined(RDNA3)
|
||||
+#elif defined(RDNA3) || defined(RDNA4)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
|
||||
index 10f2ebb1..933d945c 100644
|
||||
--- a/ggml/src/ggml-cuda/mmq.cu
|
||||
+++ b/ggml/src/ggml-cuda/mmq.cu
|
||||
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
- return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
+ return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
|
||||
index 0451c65f..66ce2bc9 100644
|
||||
--- a/ggml/src/ggml-cuda/mmq.cuh
|
||||
+++ b/ggml/src/ggml-cuda/mmq.cuh
|
||||
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||
|
||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
|
||||
index 4fb466ca..23ae7abc 100644
|
||||
--- a/ggml/src/ggml-cuda/mmvq.cu
|
||||
+++ b/ggml/src/ggml-cuda/mmvq.cu
|
||||
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index 81964611..a62544b5 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -150,6 +150,10 @@
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
+#define RDNA4
|
||||
+#endif
|
||||
+
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
)
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
// Split up the GPUs by type and try them
|
||||
var estimatedVRAM uint64
|
||||
for _, gpus := range allGpus.ByLibrary() {
|
||||
var layerCount int
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||
if opts.NumGPU < 0 {
|
||||
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
||||
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
|
||||
|
||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||
// The GPUs provided must all be the same Library
|
||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
|
||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
|
||||
// Graph size for a partial offload, applies to all GPUs
|
||||
var graphPartialOffload uint64
|
||||
|
||||
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
if len(kv) > 0 {
|
||||
layerSize += kv[0]
|
||||
}
|
||||
|
||||
var kvTotal uint64
|
||||
for _, kvLayer := range kv {
|
||||
kvTotal += kvLayer
|
||||
}
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = f.KV().GQA() * kv / 6
|
||||
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||
}
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
// Some models have inconsistent layer sizes
|
||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||
layerSize = blk.Size()
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
layerSize += kv[i]
|
||||
memoryWeights += blk.Size()
|
||||
}
|
||||
|
||||
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
layersRequested: opts.NumGPU,
|
||||
layersModel: int(f.KV().BlockCount()) + 1,
|
||||
availableList: availableList,
|
||||
kv: kv,
|
||||
kv: kvTotal,
|
||||
allocationsList: allocationsList,
|
||||
memoryWeights: memoryWeights,
|
||||
memoryLayerOutput: memoryLayerOutput,
|
||||
@@ -374,7 +380,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
||||
slog.Group(
|
||||
"weights",
|
||||
// memory of the weights
|
||||
"total", format.HumanBytes2(m.memoryWeights),
|
||||
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
|
||||
// memory of repeating layers
|
||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
||||
// memory of non-repeating layers
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
||||
projectors := []string{}
|
||||
opts := api.DefaultOptions()
|
||||
t.Run("cpu", func(t *testing.T) {
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, 0, estimate.Layers)
|
||||
assert.Equal(t, uint64(0), estimate.Graph)
|
||||
})
|
||||
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
||||
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
||||
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
||||
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||
var layerSums uint64
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/grammar"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@@ -110,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
gpus = discover.GetCPUInfo()
|
||||
}
|
||||
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
||||
switch {
|
||||
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
||||
@@ -701,9 +700,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
// User provided a JSON schema
|
||||
g, err := grammar.FromSchema(nil, req.Format)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid JSON schema in format: %w", err)
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
}
|
||||
req.Grammar = string(g)
|
||||
}
|
||||
@@ -714,11 +713,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
@@ -733,6 +727,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
||||
req.Options.NumPredict = 10 * s.options.NumCtx
|
||||
}
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatusRetry(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -110,12 +110,10 @@ type Context interface {
|
||||
MaxGraphNodes() int
|
||||
Close()
|
||||
|
||||
// Input returns a context appropriate for creating input tensors
|
||||
// Input returns a context appropriate for creating tensors that are
|
||||
// inputs to the model (which includes things like output locations)
|
||||
Input() Context
|
||||
|
||||
// Output returns a context appropriate for creating output tensors
|
||||
Output() Context
|
||||
|
||||
// Layer returns a context appropriate for creating intermediate tensors
|
||||
Layer(int) Context
|
||||
}
|
||||
|
||||
@@ -48,9 +48,6 @@ type Backend struct {
|
||||
// input is the backend used for inputs
|
||||
input *C.struct_ggml_backend_buffer_type
|
||||
|
||||
// output is the backend used for outputs
|
||||
output *C.struct_ggml_backend_buffer_type
|
||||
|
||||
// layers is the backend used for repeating layers
|
||||
layers map[int]*C.struct_ggml_backend_buffer_type
|
||||
|
||||
@@ -400,8 +397,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||
C.size_t(maxGraphNodes),
|
||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||
),
|
||||
input: deviceBufferTypes[input.d],
|
||||
output: deviceBufferTypes[output.d],
|
||||
input: deviceBufferTypes[input.d],
|
||||
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
||||
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
||||
for i, layer := range layers {
|
||||
@@ -482,19 +478,6 @@ func (c Context) Input() ml.Context {
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c Context) Output() ml.Context {
|
||||
if c.b.output != nil {
|
||||
return &Context{
|
||||
b: c.b,
|
||||
ctx: c.ctx,
|
||||
buft: c.b.output,
|
||||
maxGraphNodes: c.maxGraphNodes,
|
||||
}
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c Context) Layer(i int) ml.Context {
|
||||
if buft, ok := c.b.layers[i]; ok {
|
||||
return &Context{
|
||||
|
||||
@@ -61,11 +61,13 @@
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||
|
||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
#elif defined(RDNA3)
|
||||
#elif defined(RDNA3) || defined(RDNA4)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
|
||||
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||
|
||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||
|
||||
@@ -150,6 +150,10 @@
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#define RDNA4
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
|
||||
@@ -38,7 +38,6 @@ const (
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
|
||||
@@ -55,7 +55,6 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
|
||||
@@ -45,7 +45,6 @@ func newTextModel(c ml.Config) *TextModel {
|
||||
|
||||
m := TextModel{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
|
||||
@@ -32,7 +32,6 @@ type TextProcessor interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
Vocab() *Vocabulary
|
||||
}
|
||||
|
||||
type Vocabulary struct {
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
||||
)
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
func replaceWhitespaceBySeperator(s string) string {
|
||||
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
|
||||
}
|
||||
|
||||
type SentencePieceModel struct {
|
||||
maxTokenLen int
|
||||
pre *regexp2.Regexp
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||
|
||||
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
@@ -44,7 +38,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
||||
|
||||
return SentencePieceModel{
|
||||
maxTokenLen: maxTokenLen,
|
||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
@@ -53,24 +46,9 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Vocab() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
||||
if !yield(m.String()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := spm.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
@@ -95,7 +73,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
slog.Debug("fragments", "frags", fragments)
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
@@ -104,105 +81,96 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range spm.split(frag.value) {
|
||||
split = replaceWhitespaceBySeperator(split)
|
||||
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.Write([]byte(split))
|
||||
if id := spm.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
if id := spm.vocab.Encode(text); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
q := &queue{}
|
||||
heap.Init(q)
|
||||
|
||||
runes := []rune(text)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
pq := queue.NewWith(func(a, b any) int {
|
||||
priA := a.(*candidate)
|
||||
priB := b.(*candidate)
|
||||
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
})
|
||||
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("tokenizer", "merges", merges)
|
||||
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
}
|
||||
}
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pq.Enqueue(pair)
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
size: len(left) + len(right),
|
||||
}
|
||||
}
|
||||
|
||||
pqv := pq.Values()
|
||||
for _, v := range pqv {
|
||||
e := v.(*candidate)
|
||||
slog.Debug("candidate", "candidate", e)
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for q.Len() > 0 {
|
||||
pair := heap.Pop(q).(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
|
||||
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
||||
continue
|
||||
}
|
||||
|
||||
for !pq.Empty() {
|
||||
v, _ := pq.Dequeue()
|
||||
pair := v.(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
slog.Debug("pair", "left", left, "right", right)
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 {
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if token := string(merge.runes); token != "" {
|
||||
id := spm.vocab.Encode(token)
|
||||
|
||||
if id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pq.Enqueue(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pq.Enqueue(pair)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("merges", "merges", merges)
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
// Fallback to byte tokenization
|
||||
var result []int32
|
||||
for _, b := range []byte(token) {
|
||||
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||
unknownID := spm.vocab.Encode(byteToken)
|
||||
if unknownID >= 0 {
|
||||
result = append(result, unknownID)
|
||||
} else {
|
||||
slog.Debug("missing token", "token", string(merge.runes))
|
||||
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
|
||||
}
|
||||
}
|
||||
|
||||
ids = append(ids, result...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,6 +201,30 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||
type candidate struct {
|
||||
a, b int
|
||||
score float32
|
||||
size int
|
||||
}
|
||||
|
||||
type queue []*candidate
|
||||
|
||||
func (q queue) Len() int { return len(q) }
|
||||
|
||||
func (q queue) Less(i, j int) bool {
|
||||
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
||||
}
|
||||
|
||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||
|
||||
func (q *queue) Push(x interface{}) {
|
||||
item := x.(*candidate)
|
||||
*q = append(*q, item)
|
||||
}
|
||||
|
||||
func (q *queue) Pop() interface{} {
|
||||
old := *q
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*q = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||
@@ -240,11 +232,26 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
if _, err := sb.WriteString(data); err != nil {
|
||||
return "", err
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse hex byte: %v", err)
|
||||
}
|
||||
|
||||
if err := sb.WriteByte(byte(byteVal)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
if _, err := sb.WriteString(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("decoded", "ids", ids, "text", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -25,8 +25,6 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
|
||||
|
||||
var v Vocabulary
|
||||
|
||||
for _, piece := range spm.GetPieces() {
|
||||
@@ -47,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||
}
|
||||
}
|
||||
|
||||
return NewSentencePieceModel(preTokenizer, &v)
|
||||
return NewSentencePieceModel(&v)
|
||||
}
|
||||
|
||||
func TestSentencePieceEncode(t *testing.T) {
|
||||
@@ -116,3 +114,59 @@ func TestSentencePieceEncode(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"normal",
|
||||
"<0xEA>",
|
||||
"<0x41>",
|
||||
"<0xC3>",
|
||||
"<0xA3>",
|
||||
},
|
||||
Types: []uint32{
|
||||
TOKEN_TYPE_NORMAL,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
},
|
||||
Scores: []float32{0, 0, 0, 0, 0},
|
||||
}
|
||||
|
||||
spm := NewSentencePieceModel(vocab)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single byte token",
|
||||
ids: []int32{1},
|
||||
expected: "\xea",
|
||||
},
|
||||
{
|
||||
name: "ASCII byte token",
|
||||
ids: []int32{2},
|
||||
expected: "A",
|
||||
},
|
||||
{
|
||||
name: "multiple byte tokens forming UTF-8 character",
|
||||
ids: []int32{3, 4},
|
||||
expected: "ã",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := spm.Decode(tt.ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ import (
|
||||
var finishReasonToolCalls = "tool_calls"
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param interface{} `json:"param"`
|
||||
Code *string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param any `json:"param"`
|
||||
Code *string `json:"code"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
@@ -465,7 +465,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
options := make(map[string]interface{})
|
||||
options := make(map[string]any)
|
||||
|
||||
switch stop := r.Stop.(type) {
|
||||
case string:
|
||||
|
||||
@@ -219,7 +219,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]interface{}{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
|
||||
@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
||||
return discard
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
|
||||
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
||||
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||
}
|
||||
|
||||
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
||||
inputLen := len(slot.Inputs)
|
||||
discard := c.ShiftDiscard(inputLen, numKeep)
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
||||
}
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
||||
var shiftFailed bool
|
||||
|
||||
for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
||||
if c.lc.KvCacheCanShift() {
|
||||
// For models that support shifting, attempt to shift the KV cache
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
} else {
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
|
||||
}
|
||||
} else {
|
||||
// For models that don't support shifting
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
}
|
||||
|
||||
if shiftFailed {
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Clear the entire KV cache
|
||||
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
|
||||
// Reset the slot inputs since we've cleared the cache
|
||||
slot.Inputs = []input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
|
||||
// Standard shift succeeded - update input array
|
||||
for i := numKeep + discard; i < inputLen; i++ {
|
||||
slot.Inputs[i-discard] = slot.Inputs[i]
|
||||
}
|
||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
||||
slot.Inputs = slot.Inputs[:inputLen-discard]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Continue processing as normal
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break
|
||||
@@ -599,7 +607,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -611,6 +619,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -626,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -691,7 +701,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting embeddings request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -703,6 +713,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -715,6 +726,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -118,6 +118,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
||||
}
|
||||
|
||||
if c.cache != nil {
|
||||
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
|
||||
if err != nil {
|
||||
// Some models don't support partial erasure
|
||||
@@ -225,6 +229,8 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
||||
return count
|
||||
}
|
||||
|
||||
// TODO(jessegross): If we need to reprocess the inputs we should ensure that
|
||||
// we don't split up a SameBatch
|
||||
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
targetFree := (c.numCtx - numKeep) / 2
|
||||
targetFree = max(targetFree, 1)
|
||||
@@ -239,6 +245,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
return discard
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input.Input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
@@ -258,11 +272,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if c.cache != nil {
|
||||
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
|
||||
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
|
||||
"id", slot.Id, "error", err)
|
||||
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Reset the cache
|
||||
_ = c.cache.Remove(slot.Id, 0, -1)
|
||||
slot.Inputs = []input.Input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
@@ -425,3 +428,92 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock implementation of the Cache interface
|
||||
type mockCache struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
// Implement only the methods needed for the test
|
||||
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if m.shouldFail {
|
||||
return fmt.Errorf("mock cache removal error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stub implementations for other interface methods
|
||||
func (m *mockCache) SetLayer(layer int) {}
|
||||
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
||||
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
||||
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
|
||||
func (m *mockCache) Close() {}
|
||||
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
|
||||
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
||||
func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
||||
func (m *mockCache) CanResume(seq int, pos int32) bool { return true }
|
||||
|
||||
func TestShiftCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int32
|
||||
inputs []input.Input
|
||||
numKeep int32
|
||||
cacheErr bool
|
||||
wantErr any
|
||||
wantInputsLen int
|
||||
}{
|
||||
{
|
||||
name: "Normal shift",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: false, // No error
|
||||
wantErr: nil,
|
||||
wantInputsLen: 6, // After discarding 4 tokens
|
||||
},
|
||||
{
|
||||
name: "Cache removal fails",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: true,
|
||||
wantErr: &ErrReprocessInputs{},
|
||||
wantInputsLen: 0, // Original inputs should be cleared
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := &mockCache{shouldFail: tt.cacheErr}
|
||||
c := InputCache{
|
||||
numCtx: tt.numCtx,
|
||||
cache: mock,
|
||||
}
|
||||
slot := &InputCacheSlot{
|
||||
Id: 123,
|
||||
Inputs: make([]input.Input, len(tt.inputs)),
|
||||
}
|
||||
copy(slot.Inputs, tt.inputs)
|
||||
|
||||
err := c.ShiftCacheSlot(slot, tt.numKeep)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if !errors.As(err, &tt.wantErr) {
|
||||
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(slot.Inputs) != tt.wantInputsLen {
|
||||
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,16 +115,41 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
params.numKeep = int32(len(inputs))
|
||||
}
|
||||
|
||||
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
|
||||
// otherwise we might truncate or split the batch against the model's wishes
|
||||
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
promptStart := params.numKeep + discard
|
||||
|
||||
// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
|
||||
sameBatch := 0
|
||||
for i, inp := range inputs {
|
||||
if sameBatch > 0 {
|
||||
sameBatch--
|
||||
|
||||
if promptStart == int32(i) {
|
||||
promptStart++
|
||||
}
|
||||
} else if promptStart == int32(i) {
|
||||
break
|
||||
}
|
||||
|
||||
if inp.SameBatch != 0 {
|
||||
if int32(i) < params.numKeep {
|
||||
return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
|
||||
}
|
||||
|
||||
sameBatch = inp.SameBatch
|
||||
}
|
||||
}
|
||||
|
||||
if promptStart >= int32(len(inputs)) {
|
||||
return nil, errors.New("entire prompt removed by truncation")
|
||||
}
|
||||
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
newInputs = append(newInputs, inputs[promptStart:]...)
|
||||
|
||||
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
|
||||
inputs = newInputs
|
||||
@@ -267,6 +292,9 @@ type Server struct {
|
||||
// KV cache
|
||||
cache *InputCache
|
||||
|
||||
// next sequence for prompt processing to avoid starvation
|
||||
nextSeq int
|
||||
|
||||
// multimodalHash generates hashes for comparing equality
|
||||
// of non-text data
|
||||
multimodalHash maphash.Hash
|
||||
@@ -351,14 +379,19 @@ func (s *Server) processBatch() error {
|
||||
var batchInputs []int32
|
||||
var batch input.Batch
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
resumeSeq := -1
|
||||
seqIdx := s.nextSeq - 1
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, "limit")
|
||||
s.removeSequence(seqIdx, "limit")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -369,16 +402,23 @@ func (s *Server) processBatch() error {
|
||||
|
||||
batchSize := s.batchSize
|
||||
|
||||
for j, inp := range seq.inputs {
|
||||
for i, inp := range seq.inputs {
|
||||
// If we are required to put following inputs into a single batch then extend the
|
||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||
// will cause a break if we have pending inputs.
|
||||
// will cause a break if we have existing inputs.
|
||||
minBatch := 1 + inp.SameBatch
|
||||
if minBatch > batchSize {
|
||||
batchSize = minBatch
|
||||
}
|
||||
|
||||
if len(seq.pendingInputs)+minBatch > batchSize {
|
||||
// Stop if the required batch would put us over the total batch size (including tokens
|
||||
// added by other sequences). If we haven't been able to add anything yet then pick up
|
||||
// here again for the next batch to avoid starvation, though we can opportunistically
|
||||
// check if other sequences can still squeeze something in.
|
||||
if len(batchInputs)+minBatch > batchSize {
|
||||
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
|
||||
resumeSeq = seqIdx
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@@ -392,7 +432,15 @@ func (s *Server) processBatch() error {
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Skip this sequence but continue processing the rest
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,7 +453,7 @@ func (s *Server) processBatch() error {
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if j+1 == len(seq.inputs) {
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
@@ -414,6 +462,12 @@ func (s *Server) processBatch() error {
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
if resumeSeq != -1 {
|
||||
s.nextSeq = resumeSeq
|
||||
} else {
|
||||
s.nextSeq = seqIdx + 1
|
||||
}
|
||||
|
||||
if len(batchInputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -468,20 +522,6 @@ func (s *Server) processBatch() error {
|
||||
return fmt.Errorf("failed to sample token: %w", err)
|
||||
}
|
||||
|
||||
if seq.sampler.JSONSampler != nil {
|
||||
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if seq.sampler.PythonSampler != nil {
|
||||
err = seq.sampler.PythonSampler.UpdateState(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
@@ -576,21 +616,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
|
||||
// if err != nil {
|
||||
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
// jsonSampler = nil
|
||||
pythonSampler := &sample.PythonSampler{}
|
||||
functions := []sample.PythonFunction{
|
||||
{
|
||||
Name: "add_two_strings",
|
||||
Args: []string{"s1", "s2"},
|
||||
Types: []string{"string", "string"},
|
||||
},
|
||||
}
|
||||
pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
||||
sampler := sample.NewSampler(
|
||||
req.Options.Temperature,
|
||||
req.Options.TopK,
|
||||
@@ -598,9 +623,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
req.Options.MinP,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
nil,
|
||||
pythonSampler,
|
||||
// nil,
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
@@ -620,7 +642,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -632,6 +654,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -645,6 +668,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package sample
|
||||
|
||||
var DefaultGrammar = map[string]string{
|
||||
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
|
||||
"null": `"null"`,
|
||||
"object": `"{" (kv ("," kv)*)? "}"`,
|
||||
"array": `"[" (value ("," value)*)? "]"`,
|
||||
"kv": `string ":" value`,
|
||||
"integer": `"0" | [1-9] [0-9]*`,
|
||||
"number": `"-"? integer frac? exp?`,
|
||||
"frac": `"." [0-9]+`,
|
||||
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
|
||||
"string": `"\"" char* "\""`,
|
||||
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
|
||||
"char": `[^"\\] | escape`,
|
||||
"space": `(" " | "\t" | "\n" | "\r")*`,
|
||||
"hex": `[0-9] | [a-f] | [A-F]`,
|
||||
"boolean": `"true" | "false"`,
|
||||
"value": `object | array | string | number | boolean | "null"`,
|
||||
}
|
||||
|
||||
const jsonString = `object | array`
|
||||
|
||||
type StateMachine struct {
|
||||
states map[rune]State
|
||||
}
|
||||
|
||||
type State struct {
|
||||
NextStates []string
|
||||
// bitmask?
|
||||
Mask []bool
|
||||
IsTerminal bool
|
||||
}
|
||||
|
||||
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
|
||||
states := make(map[rune]State)
|
||||
|
||||
var cumu string
|
||||
flag := false
|
||||
for _, r := range startRule {
|
||||
if r == '"' {
|
||||
flag = !flag
|
||||
}
|
||||
if flag {
|
||||
cumu += string(r)
|
||||
}
|
||||
}
|
||||
|
||||
sm := &StateMachine{
|
||||
states: states,
|
||||
}
|
||||
return sm
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGrammarParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
grammar map[string]string
|
||||
startRule string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
grammar: map[string]string{
|
||||
"object": `"{" "}"`,
|
||||
},
|
||||
startRule: "object",
|
||||
input: "{}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "simple array",
|
||||
grammar: map[string]string{
|
||||
"array": `"[" "]"`,
|
||||
},
|
||||
startRule: "array",
|
||||
input: "[]",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "character class",
|
||||
grammar: map[string]string{
|
||||
"digit": `[0-9]`,
|
||||
},
|
||||
startRule: "digit",
|
||||
input: "5",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "alternation",
|
||||
grammar: map[string]string{
|
||||
"bool": `"true" | "false"`,
|
||||
},
|
||||
startRule: "bool",
|
||||
input: "true",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "repetition",
|
||||
grammar: map[string]string{
|
||||
"digits": `[0-9]+`,
|
||||
},
|
||||
startRule: "digits",
|
||||
input: "123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nested rules",
|
||||
grammar: map[string]string{
|
||||
"value": `object | array`,
|
||||
"object": `"{" "}"`,
|
||||
"array": `"[" "]"`,
|
||||
},
|
||||
startRule: "value",
|
||||
input: "{}",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := NewParser(tt.grammar)
|
||||
machine, err := parser.Parse(tt.startRule)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse() error = %v", err)
|
||||
}
|
||||
|
||||
matcher := NewMatcher(machine)
|
||||
got, err := matcher.Match(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Match() error = %v", err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONGrammar(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"empty object", "{}", true},
|
||||
{"empty array", "[]", true},
|
||||
{"simple string", `"hello"`, true},
|
||||
{"simple number", "123", true},
|
||||
{"simple boolean", "true", true},
|
||||
{"simple null", "null", true},
|
||||
{"object with string", `{"key": "value"}`, true},
|
||||
{"array with numbers", "[1, 2, 3]", true},
|
||||
{"nested object", `{"obj": {"key": "value"}}`, true},
|
||||
{"nested array", `[1, [2, 3], 4]`, true},
|
||||
{"invalid object", "{", false},
|
||||
{"invalid array", "[1, 2", false},
|
||||
{"invalid string", `"hello`, false},
|
||||
}
|
||||
|
||||
parser := NewParser(DefaultGrammar)
|
||||
machine, err := parser.Parse("value")
|
||||
if err != nil {
|
||||
t.Fatalf("Parse() error = %v", err)
|
||||
}
|
||||
|
||||
matcher := NewMatcher(machine)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := matcher.Match(tt.input)
|
||||
if tt.want {
|
||||
if err != nil {
|
||||
t.Errorf("Match() error = %v", err)
|
||||
}
|
||||
if !got {
|
||||
t.Errorf("Match() = false, want true")
|
||||
}
|
||||
} else {
|
||||
if err == nil && got {
|
||||
t.Errorf("Match() = true, want false")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type JSONState int
|
||||
|
||||
const (
|
||||
StateStart JSONState = iota
|
||||
StateInObject
|
||||
StateInObjectKey
|
||||
StateInStructuredKey
|
||||
StateInStructuredValue
|
||||
StateNewline
|
||||
StateTab
|
||||
StateSpace
|
||||
StateInString
|
||||
StateInInt
|
||||
StateInFloat
|
||||
StateInBool
|
||||
StateInNull
|
||||
StateInColon
|
||||
StateInComma
|
||||
StateInTab
|
||||
StateInSpaceToValue
|
||||
StateInSpaceEndValue
|
||||
StateInNewlineEndValue
|
||||
StateInObjSpace
|
||||
StateInList
|
||||
StateInListComma
|
||||
StateInValue
|
||||
StateInValueEnd
|
||||
StateInListEnd
|
||||
StateInListObjectEnd
|
||||
StateInNewline
|
||||
StateInNumber
|
||||
StateInNumberEnd
|
||||
StateInStringEnd
|
||||
StateInObjectKeyEnd
|
||||
StateTerminate
|
||||
StateInObjectEnd
|
||||
StateTransitioningToTerminate
|
||||
StateInListStartJSON
|
||||
)
|
||||
|
||||
var JSONStates = []JSONState{
|
||||
StateStart,
|
||||
StateInObject,
|
||||
StateInObjectKey,
|
||||
StateInStructuredKey,
|
||||
StateInStructuredValue,
|
||||
StateNewline,
|
||||
StateTab,
|
||||
StateSpace,
|
||||
StateInString,
|
||||
StateInInt,
|
||||
StateInFloat,
|
||||
StateInBool,
|
||||
StateInNull,
|
||||
StateInColon,
|
||||
StateInComma,
|
||||
StateInTab,
|
||||
StateInSpaceToValue,
|
||||
StateInSpaceEndValue,
|
||||
StateInNewlineEndValue,
|
||||
StateInObjSpace,
|
||||
StateInListStartJSON,
|
||||
StateInList,
|
||||
StateInListComma,
|
||||
StateInValue,
|
||||
StateInValueEnd,
|
||||
StateInListEnd,
|
||||
StateInListObjectEnd,
|
||||
StateInNewline,
|
||||
StateInNumber,
|
||||
StateInNumberEnd,
|
||||
StateInStringEnd,
|
||||
StateInObjectKeyEnd,
|
||||
StateTerminate,
|
||||
StateInObjectEnd,
|
||||
StateTransitioningToTerminate,
|
||||
}
|
||||
|
||||
func (s JSONState) String() string {
|
||||
switch s {
|
||||
case StateStart:
|
||||
return "StateStart"
|
||||
case StateInObject:
|
||||
return "StateInObject"
|
||||
case StateInObjectKey:
|
||||
return "StateInObjectKey"
|
||||
case StateInStructuredKey:
|
||||
return "StateInStructuredKey"
|
||||
case StateInStructuredValue:
|
||||
return "StateInStructuredValue"
|
||||
case StateNewline:
|
||||
return "StateNewline"
|
||||
case StateTab:
|
||||
return "StateTab"
|
||||
case StateSpace:
|
||||
return "StateSpace"
|
||||
case StateInString:
|
||||
return "StateInString"
|
||||
case StateInInt:
|
||||
return "StateInInt"
|
||||
case StateInFloat:
|
||||
return "StateInFloat"
|
||||
case StateInBool:
|
||||
return "StateInBool"
|
||||
case StateInNull:
|
||||
return "StateInNull"
|
||||
case StateInColon:
|
||||
return "StateInColon"
|
||||
case StateInComma:
|
||||
return "StateInComma"
|
||||
case StateInTab:
|
||||
return "StateInTab"
|
||||
case StateInSpaceToValue:
|
||||
return "StateInSpaceToValue"
|
||||
case StateInSpaceEndValue:
|
||||
return "StateInSpaceEndValue"
|
||||
case StateInNewlineEndValue:
|
||||
return "StateInNewlineEndValue"
|
||||
case StateInObjSpace:
|
||||
return "StateInObjSpace"
|
||||
case StateInList:
|
||||
return "StateInList"
|
||||
case StateInListComma:
|
||||
return "StateInListComma"
|
||||
case StateInValue:
|
||||
return "StateInValue"
|
||||
case StateInValueEnd:
|
||||
return "StateInValueEnd"
|
||||
case StateInListEnd:
|
||||
return "StateInListEnd"
|
||||
case StateInListObjectEnd:
|
||||
return "StateInListObjectEnd"
|
||||
case StateInNewline:
|
||||
return "StateInNewline"
|
||||
case StateInNumber:
|
||||
return "StateInNumber"
|
||||
case StateInNumberEnd:
|
||||
return "StateInNumberEnd"
|
||||
case StateInStringEnd:
|
||||
return "StateInStringEnd"
|
||||
case StateInObjectKeyEnd:
|
||||
return "StateInObjectKeyEnd"
|
||||
case StateTerminate:
|
||||
return "StateTerminate"
|
||||
case StateInObjectEnd:
|
||||
return "StateInObjectEnd"
|
||||
case StateTransitioningToTerminate:
|
||||
return "StateTransitioningToTerminate"
|
||||
case StateInListStartJSON:
|
||||
return "StateInListStartJSON"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown state: %d", s)
|
||||
}
|
||||
}
|
||||
@@ -1,327 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
/*
|
||||
Key JSON rules to consider:
|
||||
|
||||
1. Whitespace handling:
|
||||
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
|
||||
- Current code only handles some whitespace cases
|
||||
|
||||
2. Number validation:
|
||||
- Need proper validation for special number cases like -0
|
||||
- Should handle .5 style decimals
|
||||
- Need limits on scientific notation (e, E)
|
||||
|
||||
3. String escaping:
|
||||
- Currently marks \ as invalid but should allow escaped sequences:
|
||||
- \"
|
||||
- \n
|
||||
- \u1234 unicode escapes
|
||||
|
||||
4. Empty object/array transitions:
|
||||
- Direct {} and [] cases could be more explicit
|
||||
- Need clear transitions for these edge cases
|
||||
|
||||
5. Nested depth limits:
|
||||
- No protection against excessive nesting
|
||||
- Could cause stack overflow with deeply nested structures
|
||||
*/
|
||||
|
||||
// TODO: / should be valid but an escape character
|
||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||
|
||||
var (
|
||||
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||
)
|
||||
|
||||
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||
|
||||
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||
|
||||
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||
|
||||
type PDA struct {
|
||||
State JSONState
|
||||
TransitionEdges map[rune]*PDA
|
||||
MaskTokenIDToNode map[int32]*PDA
|
||||
}
|
||||
|
||||
func NewPDANode(state JSONState) *PDA {
|
||||
return &PDA{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
}
|
||||
|
||||
type PDAGraphBuilder struct {
|
||||
proc model.TextProcessor
|
||||
decodedToks []string
|
||||
stateToNodeMap map[JSONState]*PDA
|
||||
tokenToStatesMap map[int32][]JSONState
|
||||
}
|
||||
|
||||
func (b *PDAGraphBuilder) BuildGraph() error {
|
||||
stateToNodeMap := make(map[JSONState]*PDA)
|
||||
for _, state := range JSONStates {
|
||||
stateToNodeMap[state] = NewPDANode(state)
|
||||
}
|
||||
|
||||
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||
|
||||
// TODO: update naming here - and revisit values
|
||||
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
|
||||
|
||||
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
|
||||
// new line
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
|
||||
// new line end value
|
||||
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
// TODO: see if this is needed for formatting
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
|
||||
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
|
||||
|
||||
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
|
||||
|
||||
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
|
||||
|
||||
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
|
||||
// where values should be
|
||||
// this could be combined but the probl might change, we're alr doing a skip ahead
|
||||
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||
|
||||
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
||||
|
||||
// Leads to a value
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
|
||||
|
||||
// Values
|
||||
// string node
|
||||
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
|
||||
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
||||
|
||||
// String end node
|
||||
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
||||
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||
// number node
|
||||
for _, r := range validNumberRunes {
|
||||
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
||||
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// list node
|
||||
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
// early end
|
||||
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
|
||||
// list end node
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// empty list
|
||||
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
||||
|
||||
// null node
|
||||
for _, r := range validNullRunes {
|
||||
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
|
||||
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// list comma
|
||||
// should point to values
|
||||
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
|
||||
|
||||
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
||||
|
||||
// list object end
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
// TODO: not sure if this is needed
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||
}
|
||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// comma node
|
||||
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
// todo: review this space transition
|
||||
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
|
||||
// space end value
|
||||
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
b.stateToNodeMap = stateToNodeMap
|
||||
if err := b.preComputeValidStates(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
}
|
||||
|
||||
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
|
||||
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
||||
for _, r := range validNumberRunes {
|
||||
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||
}
|
||||
// TODO(parthsareen): force the output and shift similar to structured outputs
|
||||
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||
}
|
||||
|
||||
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
||||
for _, node := range b.stateToNodeMap {
|
||||
// if node.State == StateInObjectKey {
|
||||
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
|
||||
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
|
||||
// fmt.Println("copying string mask to object key mask")
|
||||
// }
|
||||
// }
|
||||
if err := b.CreateMask(node); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
|
||||
// TODO: make can be somewhere else too
|
||||
b.tokenToStatesMap = make(map[int32][]JSONState)
|
||||
for i, t := range b.decodedToks {
|
||||
for _, r := range t {
|
||||
if r == '"' {
|
||||
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: the mask for obj key and string should be the same?
|
||||
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
||||
if node == nil {
|
||||
return fmt.Errorf("node cannot be nil")
|
||||
}
|
||||
for i := range b.decodedToks {
|
||||
token := b.decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
curNode := node
|
||||
valid := true
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if curNode == nil || !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||
if specialRune {
|
||||
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific rune transition
|
||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||
// fmt.Println("next node", nextNode)
|
||||
if specialRune {
|
||||
if curNode.State == nextNode.State {
|
||||
return nil, false
|
||||
}
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
// Check for sentinel value - if present, any rune is valid
|
||||
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
@@ -1,264 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
// TODO: safety in case of invalid json
|
||||
// TODO: partial JSON matching?
|
||||
// TODO: interfaces to cleanup with return values
|
||||
// TODO this interface shouldn't be the sampler - should just use Sampler
|
||||
// TODO: add penalties for string \n stuff
|
||||
// TODO: minimize number of fwd passes if there is only one match
|
||||
// TODO: greedy sample initially and then backtrack if no match
|
||||
|
||||
type PushdownSampler struct {
|
||||
PDAGraphBuilder
|
||||
curNode *PDA
|
||||
braceStack []rune
|
||||
stateCounter uint32
|
||||
}
|
||||
|
||||
// graph should be built once and reused per tokenizer
|
||||
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
|
||||
start := time.Now()
|
||||
|
||||
fmt.Println("--------------------------------")
|
||||
fmt.Println("PDA sampler")
|
||||
fmt.Println("--------------------------------")
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
|
||||
vocab := proc.Vocab()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
|
||||
gb := &PDAGraphBuilder{
|
||||
proc: proc,
|
||||
decodedToks: decodedToks,
|
||||
}
|
||||
|
||||
if err := gb.BuildGraph(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
|
||||
// TODO: this can be simplified
|
||||
return &PushdownSampler{
|
||||
curNode: gb.stateToNodeMap[StateStart],
|
||||
PDAGraphBuilder: *gb,
|
||||
braceStack: []rune{},
|
||||
stateCounter: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
|
||||
switch s.curNode.State {
|
||||
case StateInString:
|
||||
return s.maskLogits(logits, s.curNode)
|
||||
|
||||
case StateInListEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
return forceFinish(s, logits)
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateTerminate:
|
||||
return forceFinish(s, logits)
|
||||
|
||||
case StateInObjectEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
return forceFinish(s, logits)
|
||||
}
|
||||
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInComma:
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
fmt.Println("masking logits current state", s.curNode.State)
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
}
|
||||
|
||||
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
|
||||
for i := range logits {
|
||||
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||
fmt.Println("current state - updating", s.curNode.State)
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||
|
||||
// Special handling for EOS token in terminate state
|
||||
if s.curNode.State == StateTerminate {
|
||||
for _, tokenID := range tokenSlice {
|
||||
if s.proc.Is(tokenID, model.SpecialEOS) {
|
||||
return tokenSlice, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flag := -1
|
||||
// endBraceRunes := []rune{'}', ']'}
|
||||
for _, r := range mappedString {
|
||||
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
|
||||
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
|
||||
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
|
||||
// // flag = i
|
||||
// break
|
||||
|
||||
// }
|
||||
if r == rune('{') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('[') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if top != rune('{') {
|
||||
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
}
|
||||
|
||||
if r == rune(']') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if top != rune('[') {
|
||||
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
}
|
||||
}
|
||||
|
||||
// if flag != -1 {
|
||||
// tokenSlice = tokenSlice[:flag]
|
||||
// }
|
||||
// fmt.Println("flag!", flag)
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
fmt.Println("transitioning to", nextNode.State)
|
||||
|
||||
// TODO: add a penalty for staying in the same state too long
|
||||
if nextNode.State == s.curNode.State {
|
||||
s.stateCounter++
|
||||
} else {
|
||||
s.stateCounter = 0
|
||||
}
|
||||
s.curNode = nextNode
|
||||
fmt.Println("updated curNode state", s.curNode.State)
|
||||
}
|
||||
return tokenSlice, nil
|
||||
}
|
||||
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
|
||||
// Create a new slice with same length as logits, initialized to -Inf
|
||||
maskedLogits := make([]float32, len(logits))
|
||||
for i := range maskedLogits {
|
||||
maskedLogits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
// Only update values for valid token IDs from the mask map
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) {
|
||||
maskedLogits[tokenID] = logits[tokenID]
|
||||
}
|
||||
}
|
||||
|
||||
return maskedLogits, nil
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
maxIndex := -1
|
||||
|
||||
// Find the maximum logit value among valid tokens
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
|
||||
maxLogit = logits[tokenID]
|
||||
maxIndex = int(tokenID)
|
||||
}
|
||||
}
|
||||
|
||||
if maxIndex == -1 {
|
||||
return nil, fmt.Errorf("no valid tokens found in mask")
|
||||
}
|
||||
|
||||
logits[0] = float32(maxIndex)
|
||||
return logits, nil
|
||||
// return maxIndex, nil
|
||||
}
|
||||
@@ -17,14 +17,12 @@ type token struct {
|
||||
}
|
||||
|
||||
type Sampler struct {
|
||||
rng *rand.Rand
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
grammar *Grammar
|
||||
JSONSampler *JSONSampler
|
||||
PythonSampler *PythonSampler
|
||||
rng *rand.Rand
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
grammar *Grammar
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
@@ -32,19 +30,6 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
var err error
|
||||
if s.JSONSampler != nil {
|
||||
logits, err = s.JSONSampler.Apply(logits)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
if s.PythonSampler != nil {
|
||||
logits, err = s.PythonSampler.ApplyMask(logits)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
tokens[i].id = int32(i)
|
||||
@@ -142,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
@@ -170,14 +155,12 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
||||
}
|
||||
|
||||
return Sampler{
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
grammar: grammar,
|
||||
JSONSampler: jsonSampler,
|
||||
PythonSampler: pythonSampler,
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
grammar: grammar,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,299 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/grammar/jsonschema"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type JSONSampler struct {
|
||||
schema *jsonschema.Schema
|
||||
propIdx int
|
||||
propToNodeMap map[string]*PDA
|
||||
pdaSampler *PushdownSampler
|
||||
decodedToks []string
|
||||
}
|
||||
|
||||
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
|
||||
slog.Info("NewJSONSampler", "schema", schema)
|
||||
if proc == nil {
|
||||
return nil, fmt.Errorf("TextProcessor cannot be nil")
|
||||
}
|
||||
|
||||
pdaSampler, err := NewPushdownSampler(proc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
|
||||
}
|
||||
|
||||
if schema == nil {
|
||||
return &JSONSampler{
|
||||
schema: nil,
|
||||
propIdx: -1,
|
||||
propToNodeMap: nil,
|
||||
pdaSampler: pdaSampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fmt.Println("schema not nil")
|
||||
so := &JSONSampler{
|
||||
schema: schema,
|
||||
propIdx: -1,
|
||||
propToNodeMap: make(map[string]*PDA),
|
||||
pdaSampler: pdaSampler,
|
||||
}
|
||||
|
||||
so.schemaToGraph()
|
||||
|
||||
// Benchmark token decoding
|
||||
start := time.Now()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
|
||||
vocab := proc.Vocab()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
so.decodedToks = decodedToks
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Token decode time = %v\n", time.Since(start))
|
||||
|
||||
fmt.Println("--------------------------------")
|
||||
fmt.Println("SOSampler")
|
||||
fmt.Println("--------------------------------")
|
||||
// Benchmark this section
|
||||
start = time.Now()
|
||||
runtime.ReadMemStats(&m)
|
||||
before = m.Alloc
|
||||
|
||||
// TODO: still messed up
|
||||
// TODO: recursion use case
|
||||
// key masks
|
||||
for _, prop := range so.schema.Properties {
|
||||
node := so.propToNodeMap[prop.Name]
|
||||
// propName -> node
|
||||
curState := node.State
|
||||
fromNode := node
|
||||
so.pdaSampler.CreateMask(fromNode)
|
||||
for curState == StateInStructuredKey {
|
||||
// there is only one edge
|
||||
for r, toNode := range fromNode.TransitionEdges {
|
||||
fmt.Println("rune", r, "edge", toNode.State)
|
||||
so.pdaSampler.CreateMask(toNode)
|
||||
fmt.Printf("created mask for %c\n", r)
|
||||
curState = toNode.State
|
||||
fmt.Println("next state", curState)
|
||||
// TODO: theres an extra gen for " right now
|
||||
fromNode = toNode
|
||||
}
|
||||
}
|
||||
|
||||
if curState != StateInColon {
|
||||
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
|
||||
}
|
||||
|
||||
// so.pdaSampler.CreateMask(fromNode)
|
||||
|
||||
fromNode = fromNode.TransitionEdges[' ']
|
||||
|
||||
so.pdaSampler.CreateMask(fromNode)
|
||||
curState = fromNode.State
|
||||
for _, toNode := range fromNode.TransitionEdges {
|
||||
fmt.Println("toNode", toNode.State)
|
||||
}
|
||||
}
|
||||
|
||||
// runtime.ReadMemStats(&m)
|
||||
// after = m.Alloc
|
||||
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
||||
// fmt.Println("--------------------------------")
|
||||
|
||||
return so, nil
|
||||
}
|
||||
|
||||
func (s *JSONSampler) schemaToGraph() {
|
||||
schemaType := s.schema.EffectiveType()
|
||||
switch schemaType {
|
||||
case "object":
|
||||
// TODO: see if we need to connect these to the JSON graph
|
||||
|
||||
// each prop is a key
|
||||
for _, prop := range s.schema.Properties {
|
||||
// name of key
|
||||
name := prop.Name
|
||||
keyNode := &PDA{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
|
||||
prevNode := keyNode
|
||||
for _, r := range name {
|
||||
runeNode := &PDA{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
// fmt.Println("runeNode created", runeNode.State)
|
||||
// fmt.Printf("runeNode created %c\n", r)
|
||||
|
||||
// since alloc on heap connections wil still map
|
||||
prevNode.TransitionEdges[r] = runeNode
|
||||
prevNode = runeNode
|
||||
}
|
||||
|
||||
// point to end of object key node after all chars are done
|
||||
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
||||
|
||||
// link to value node
|
||||
// Create a node for the end of the key (after the closing quote)
|
||||
stringEndNode := &PDA{
|
||||
State: StateInStructuredKey,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
prevNode.TransitionEdges['"'] = stringEndNode
|
||||
prevNode = stringEndNode
|
||||
|
||||
// Add transition for colon after key
|
||||
colonNode := &PDA{
|
||||
State: StateInColon,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
prevNode.TransitionEdges[':'] = colonNode
|
||||
prevNode = colonNode
|
||||
|
||||
// Add transition for space after colon
|
||||
spaceNode := &PDA{
|
||||
State: StateInSpaceToValue,
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
prevNode.TransitionEdges[' '] = spaceNode
|
||||
prevNode = spaceNode
|
||||
|
||||
value := prop.Type
|
||||
switch value {
|
||||
case "object":
|
||||
fmt.Println("object under key: ", name)
|
||||
case "array":
|
||||
fmt.Println("array under key: ", name)
|
||||
case "string":
|
||||
fmt.Println("string under key: ", name)
|
||||
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
|
||||
case "number":
|
||||
fmt.Println("number under key: ", name)
|
||||
for _, r := range validNumberRunes {
|
||||
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
|
||||
}
|
||||
case "boolean":
|
||||
fmt.Println("boolean under key: ", name)
|
||||
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
||||
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
|
||||
}
|
||||
|
||||
// points to start of the key
|
||||
s.propToNodeMap[name] = keyNode
|
||||
fmt.Println("name", name, "keyNode", keyNode.State)
|
||||
}
|
||||
}
|
||||
// TODO: do values + recursion
|
||||
}
|
||||
|
||||
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
|
||||
if s.schema == nil {
|
||||
return s.pdaSampler.Apply(logits)
|
||||
}
|
||||
|
||||
switch s.pdaSampler.curNode.State {
|
||||
// TODO: doesnt account for multi rune case
|
||||
case StateInObjectKey:
|
||||
if s.propIdx > len(s.schema.Properties)-1 {
|
||||
return nil, fmt.Errorf("propIdx out of bounds")
|
||||
}
|
||||
// fmt.Println("in object key - structured outputs")
|
||||
// TODO: this tracking should probably be coming from a stack to track nested objects
|
||||
// simple case
|
||||
s.propIdx++
|
||||
fmt.Println("propIdx", s.propIdx)
|
||||
prop := s.schema.Properties[s.propIdx]
|
||||
fmt.Println("prop", prop.Name)
|
||||
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
||||
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
|
||||
// Will only happen for the last prop - can also be precomputed.
|
||||
if s.propIdx == len(s.schema.Properties)-1 {
|
||||
// todo: if i incremenet propidx then i know im in last value as well
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectEnd:
|
||||
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
|
||||
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
|
||||
s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
||||
s.propIdx++
|
||||
|
||||
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
|
||||
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
||||
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
||||
delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
||||
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
|
||||
|
||||
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
|
||||
s.propIdx++
|
||||
}
|
||||
}
|
||||
return s.pdaSampler.Apply(logits)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
||||
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.schema == nil {
|
||||
// Don't need to update state for unconstrained JSON sampling
|
||||
return tokenSlice, nil
|
||||
}
|
||||
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectKey:
|
||||
s.propIdx++
|
||||
fmt.Println("propIdx", s.propIdx)
|
||||
prop := s.schema.Properties[s.propIdx]
|
||||
fmt.Println("prop", prop.Name)
|
||||
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||
// TODO: this does not work - mike
|
||||
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// fmt.Println("str", str)
|
||||
|
||||
return tokenSlice, nil
|
||||
default:
|
||||
return tokenSlice, nil
|
||||
}
|
||||
}
|
||||
@@ -1,352 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type PythonState int
|
||||
|
||||
const (
|
||||
PythonStateStart PythonState = iota
|
||||
StateInFunction
|
||||
StateInFunctionArgs
|
||||
StateInFunctionArgsType
|
||||
StateInFunctionEnd
|
||||
PStateInString
|
||||
PStateInStringEnd
|
||||
PStateInNumber
|
||||
PStateInList
|
||||
PStateInListEnd
|
||||
PStateInDict
|
||||
PStateInDictEnd
|
||||
PStateInTuple
|
||||
PStateInTupleEnd
|
||||
PStateTerminate
|
||||
)
|
||||
|
||||
func (s PythonState) String() string {
|
||||
switch s {
|
||||
case PythonStateStart:
|
||||
return "PythonStateStart"
|
||||
case StateInFunction:
|
||||
return "StateInFunction"
|
||||
case StateInFunctionArgs:
|
||||
return "StateInFunctionArgs"
|
||||
case StateInFunctionArgsType:
|
||||
return "StateInFunctionArgsType"
|
||||
case StateInFunctionEnd:
|
||||
return "StateInFunctionEnd"
|
||||
case PStateInString:
|
||||
return "PStateInString"
|
||||
case PStateInStringEnd:
|
||||
return "PStateInStringEnd"
|
||||
case PStateInNumber:
|
||||
return "PStateInNumber"
|
||||
case PStateInList:
|
||||
return "PStateInList"
|
||||
case PStateInListEnd:
|
||||
return "PStateInListEnd"
|
||||
case PStateInDict:
|
||||
return "PStateInDict"
|
||||
case PStateInDictEnd:
|
||||
return "PStateInDictEnd"
|
||||
case PStateInTuple:
|
||||
return "PStateInTuple"
|
||||
case PStateInTupleEnd:
|
||||
return "PStateInTupleEnd"
|
||||
case PStateTerminate:
|
||||
return "PStateTerminate"
|
||||
default:
|
||||
return fmt.Sprintf("PythonState(%d)", s)
|
||||
}
|
||||
}
|
||||
|
||||
var PythonStates = []PythonState{
|
||||
PythonStateStart,
|
||||
StateInFunction,
|
||||
StateInFunctionArgs,
|
||||
StateInFunctionArgsType,
|
||||
StateInFunctionEnd,
|
||||
PStateInString,
|
||||
PStateInStringEnd,
|
||||
PStateInNumber,
|
||||
PStateInList,
|
||||
PStateInListEnd,
|
||||
PStateInDict,
|
||||
PStateInDictEnd,
|
||||
PStateInTuple,
|
||||
PStateInTupleEnd,
|
||||
PStateTerminate,
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
State PythonState
|
||||
TransitionEdges map[rune]*Node
|
||||
MaskTokenIDToNode map[int32]*Node
|
||||
}
|
||||
|
||||
func NewNode(state PythonState) *Node {
|
||||
return &Node{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*Node),
|
||||
MaskTokenIDToNode: make(map[int32]*Node),
|
||||
}
|
||||
}
|
||||
|
||||
type PythonFunction struct {
|
||||
Name string
|
||||
Args []string
|
||||
Types []string
|
||||
}
|
||||
|
||||
type PythonSampler struct {
|
||||
stateToNodes map[PythonState]*Node
|
||||
proc model.TextProcessor
|
||||
decodedToks []string
|
||||
curNode *Node
|
||||
completed int
|
||||
functions []PythonFunction
|
||||
}
|
||||
|
||||
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
||||
s.proc = proc
|
||||
s.functions = functions
|
||||
decodedToks := make([]string, len(proc.Vocab().Values))
|
||||
for i := range proc.Vocab().Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
s.decodedToks = decodedToks
|
||||
s.BuildGraph()
|
||||
for _, function := range functions {
|
||||
prevNode := s.stateToNodes[PythonStateStart]
|
||||
|
||||
for _, r := range function.Name {
|
||||
nextNode := NewNode(StateInFunction)
|
||||
prevNode.TransitionEdges[r] = nextNode
|
||||
if err := s.CreateMask(nextNode); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("prevNode", prevNode.State)
|
||||
fmt.Printf("transition edge: %q\n", r)
|
||||
fmt.Println("nextNode", nextNode.State)
|
||||
prevNode = nextNode
|
||||
}
|
||||
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
|
||||
s.CreateMask(prevNode)
|
||||
prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||
for i, arg := range function.Args {
|
||||
for _, r := range arg {
|
||||
nextNode := NewNode(StateInFunctionArgs)
|
||||
prevNode.TransitionEdges[r] = nextNode
|
||||
s.CreateMask(prevNode)
|
||||
prevNode = nextNode
|
||||
}
|
||||
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||
// prevNode = s.stateToNodes[StateInFunctionArgs]
|
||||
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
|
||||
s.CreateMask(prevNode)
|
||||
prevNode = prevNode.TransitionEdges['=']
|
||||
switch function.Types[i] {
|
||||
case "string":
|
||||
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
|
||||
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||
case "number":
|
||||
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
|
||||
s.CreateMask(prevNode.TransitionEdges['"'])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
s.curNode = s.stateToNodes[PythonStateStart]
|
||||
fmt.Println("curNode", s.curNode.State)
|
||||
fmt.Println("transition edges", s.curNode.TransitionEdges)
|
||||
if err := s.CreateMask(s.curNode); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
|
||||
for tokenID, node := range s.curNode.MaskTokenIDToNode {
|
||||
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) BuildGraph() error {
|
||||
s.stateToNodes = make(map[PythonState]*Node)
|
||||
for _, state := range PythonStates {
|
||||
s.stateToNodes[state] = NewNode(state)
|
||||
}
|
||||
|
||||
for _, state := range s.stateToNodes {
|
||||
if err := s.CreateMask(state); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// String
|
||||
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
|
||||
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
|
||||
|
||||
// String end
|
||||
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
// Number
|
||||
for _, r := range validNumberRunes {
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
||||
}
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
|
||||
if s.curNode.State == PStateTerminate {
|
||||
logits, err := finish(s, logits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) UpdateState(token int32) error {
|
||||
mappedString, err := s.proc.Decode([]int32{token})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||
|
||||
if s.curNode.State == PStateTerminate {
|
||||
if s.proc.Is(token, model.SpecialEOS) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
|
||||
if mappedString == "\"" {
|
||||
if s.curNode.State == PStateInStringEnd {
|
||||
s.completed++
|
||||
}
|
||||
if s.completed == len(s.functions) {
|
||||
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
s.CreateMask(s.curNode)
|
||||
}
|
||||
}
|
||||
s.curNode = nextNode
|
||||
fmt.Println("curNode", s.curNode.State)
|
||||
for r, node := range s.curNode.TransitionEdges {
|
||||
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
|
||||
}
|
||||
if err := s.CreateMask(s.curNode); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PythonSampler) CreateMask(node *Node) error {
|
||||
if node == nil {
|
||||
return fmt.Errorf("node cannot be nil")
|
||||
}
|
||||
for i := range s.decodedToks {
|
||||
token := s.decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
curNode := node
|
||||
valid := true
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
|
||||
if curNode == nil || !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
if curNode.State == StateInFunction {
|
||||
// fmt.Println("cm curNode", curNode.State)
|
||||
// fmt.Println("cm token", s.decodedToks[i])
|
||||
}
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||
if specialRune {
|
||||
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific rune transition
|
||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||
// fmt.Println("next node", nextNode)
|
||||
if specialRune {
|
||||
if curNode.State == nextNode.State {
|
||||
return nil, false
|
||||
}
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
// Check for sentinel value - if present, any rune is valid
|
||||
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
|
||||
// Create a new slice with same length as logits, initialized to -Inf
|
||||
maskedLogits := make([]float32, len(logits))
|
||||
for i := range maskedLogits {
|
||||
maskedLogits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
// Only update values for valid token IDs from the mask map
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) {
|
||||
maskedLogits[tokenID] = logits[tokenID]
|
||||
}
|
||||
}
|
||||
|
||||
return maskedLogits, nil
|
||||
}
|
||||
|
||||
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
|
||||
for i := range logits {
|
||||
if s.proc.Is(int32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
@@ -29,8 +29,9 @@ import (
|
||||
const maxRetries = 6
|
||||
|
||||
var (
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
|
||||
)
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
@@ -236,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
|
||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) > 10 {
|
||||
return errors.New("maximum redirects exceeded (10) for directURL")
|
||||
return errMaxRedirectsExceeded
|
||||
}
|
||||
|
||||
// if the hostname is the same, allow the redirect
|
||||
|
||||
108
server/images.go
108
server/images.go
@@ -35,14 +35,9 @@ var (
|
||||
errCapabilityCompletion = errors.New("completion")
|
||||
errCapabilityTools = errors.New("tools")
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
)
|
||||
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
type registryOptions struct {
|
||||
@@ -65,52 +60,83 @@ type Model struct {
|
||||
System string
|
||||
License []string
|
||||
Digest string
|
||||
Options map[string]interface{}
|
||||
Options map[string]any
|
||||
Messages []api.Message
|
||||
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
r, err := os.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer r.Close()
|
||||
|
||||
f, _, err := ggml.Decode(r, 0)
|
||||
if err == nil {
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't decode ggml", "error", err)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
}
|
||||
|
||||
if m.Template == nil {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// Check for tools capability
|
||||
if slices.Contains(m.Template.Vars(), "tools") {
|
||||
capabilities = append(capabilities, model.CapabilityTools)
|
||||
}
|
||||
|
||||
// Check for insert capability
|
||||
if slices.Contains(m.Template.Vars(), "suffix") {
|
||||
capabilities = append(capabilities, model.CapabilityInsert)
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
|
||||
// any missing or unknown capabilities
|
||||
func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||
func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
available := m.Capabilities()
|
||||
var errs []error
|
||||
for _, cap := range caps {
|
||||
switch cap {
|
||||
case CapabilityCompletion:
|
||||
r, err := os.Open(m.ModelPath)
|
||||
if err != nil {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
continue
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
|
||||
f, _, err := ggml.Decode(r, 0)
|
||||
if err != nil {
|
||||
slog.Error("couldn't decode ggml", "error", err)
|
||||
continue
|
||||
}
|
||||
// Map capabilities to their corresponding error
|
||||
capToErr := map[model.Capability]error{
|
||||
model.CapabilityCompletion: errCapabilityCompletion,
|
||||
model.CapabilityTools: errCapabilityTools,
|
||||
model.CapabilityInsert: errCapabilityInsert,
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
}
|
||||
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
errs = append(errs, errCapabilityCompletion)
|
||||
}
|
||||
case CapabilityTools:
|
||||
if !slices.Contains(m.Template.Vars(), "tools") {
|
||||
errs = append(errs, errCapabilityTools)
|
||||
}
|
||||
case CapabilityInsert:
|
||||
vars := m.Template.Vars()
|
||||
if !slices.Contains(vars, "suffix") {
|
||||
errs = append(errs, errCapabilityInsert)
|
||||
}
|
||||
default:
|
||||
for _, cap := range want {
|
||||
err, ok := capToErr[cap]
|
||||
if !ok {
|
||||
slog.Error("unknown capability", "capability", cap)
|
||||
return fmt.Errorf("unknown capability: %s", cap)
|
||||
}
|
||||
|
||||
if !slices.Contains(available, cap) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := errors.Join(errs...); err != nil {
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||
}
|
||||
|
||||
@@ -479,7 +505,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errors.New("insecure protocol http")
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
manifest, _, err := GetManifest(mp)
|
||||
@@ -543,7 +569,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errors.New("insecure protocol http")
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
360
server/images_test.go
Normal file
360
server/images_test.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Constants for GGUF magic bytes and version
|
||||
var (
|
||||
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
|
||||
ggufVer = uint32(3) // Version 3
|
||||
)
|
||||
|
||||
// Helper function to create mock GGUF data
|
||||
func createMockGGUFData(architecture string, vision bool) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Write GGUF header
|
||||
buf.Write(ggufMagic)
|
||||
binary.Write(&buf, binary.LittleEndian, ggufVer)
|
||||
|
||||
// Write tensor count (0 for our test)
|
||||
var numTensors uint64 = 0
|
||||
binary.Write(&buf, binary.LittleEndian, numTensors)
|
||||
|
||||
// Calculate number of metadata entries
|
||||
numMetaEntries := uint64(1) // architecture entry
|
||||
if vision {
|
||||
numMetaEntries++
|
||||
}
|
||||
// Add embedding entry if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
numMetaEntries++
|
||||
}
|
||||
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
|
||||
|
||||
// Write architecture metadata
|
||||
archKey := "general.architecture"
|
||||
keyLen := uint64(len(archKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(archKey)
|
||||
|
||||
// String type (8)
|
||||
var strType uint32 = 8
|
||||
binary.Write(&buf, binary.LittleEndian, strType)
|
||||
|
||||
// String length
|
||||
strLen := uint64(len(architecture))
|
||||
binary.Write(&buf, binary.LittleEndian, strLen)
|
||||
buf.WriteString(architecture)
|
||||
|
||||
if vision {
|
||||
visionKey := architecture + ".vision.block_count"
|
||||
keyLen = uint64(len(visionKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(visionKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var countVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, countVal)
|
||||
}
|
||||
// Write embedding metadata if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
poolKey := architecture + ".pooling_type"
|
||||
keyLen = uint64(len(poolKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(poolKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var poolingVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, poolingVal)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir, err := os.MkdirTemp("", "model_capabilities_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create different types of mock model files
|
||||
completionModelPath := filepath.Join(tempDir, "model.bin")
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
// Create a simple model file for tests that don't depend on GGUF content
|
||||
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
|
||||
|
||||
err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create completion model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create completion model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create simple model file: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
testModels := []struct {
|
||||
name string
|
||||
model Model
|
||||
expectedCaps []model.Capability
|
||||
}{
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion},
|
||||
},
|
||||
|
||||
{
|
||||
name: "model with completion, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools and insert capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with vision, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: Model{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
}
|
||||
|
||||
// compare two slices of model.Capability regardless of order
|
||||
compareCapabilities := func(a, b []model.Capability) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
aCount := make(map[model.Capability]int)
|
||||
for _, cap := range a {
|
||||
aCount[cap]++
|
||||
}
|
||||
|
||||
bCount := make(map[model.Capability]int)
|
||||
for _, cap := range b {
|
||||
bCount[cap]++
|
||||
}
|
||||
|
||||
for cap, count := range aCount {
|
||||
if bCount[cap] != count {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
for _, tt := range testModels {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test Capabilities method
|
||||
caps := tt.model.Capabilities()
|
||||
if !compareCapabilities(caps, tt.expectedCaps) {
|
||||
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCheckCapabilities(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
simpleModelPath := filepath.Join(tempDir, "model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
|
||||
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create simple model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vision model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model Model
|
||||
checkCaps []model.Capability
|
||||
expectedErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "completion model without tools capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools},
|
||||
expectedErrMsg: "does not support tools",
|
||||
},
|
||||
{
|
||||
name: "model with all needed capabilities",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model missing insert capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityInsert},
|
||||
expectedErrMsg: "does not support insert",
|
||||
},
|
||||
{
|
||||
name: "model missing vision capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
expectedErrMsg: "does not support vision",
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: Model{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
{
|
||||
name: "unknown capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{"unknown"},
|
||||
expectedErrMsg: "unknown capability",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test CheckCapabilities method
|
||||
err := tt.model.CheckCapabilities(tt.checkCaps...)
|
||||
if tt.expectedErrMsg == "" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
|
||||
} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
|
||||
t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func canRetry(err error) bool {
|
||||
var re *Error
|
||||
if !errors.As(err, &re) {
|
||||
return false
|
||||
}
|
||||
return re.Status >= 500
|
||||
}
|
||||
|
||||
// trackingReader is an io.Reader that tracks the number of bytes read and
|
||||
// calls the update function with the layer, the number of bytes read.
|
||||
//
|
||||
@@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
break
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf(
|
||||
"v1 pull chunksum %s %s %d-%d",
|
||||
l.Digest,
|
||||
cs.Digest,
|
||||
cs.Chunk.Start,
|
||||
cs.Chunk.End,
|
||||
)
|
||||
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
received.Add(cs.Chunk.Size())
|
||||
t.update(l, cs.Chunk.Size(), ErrCached)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
// Ignore cache key write errors for now. We've already
|
||||
// reported to trace that the chunk is complete.
|
||||
//
|
||||
// Ideally, we should only report completion to trace
|
||||
// after successful cache commit. This current approach
|
||||
// works but could trigger unnecessary redownloads if
|
||||
// the checkpoint key is missing on next pull.
|
||||
//
|
||||
// Not incorrect, just suboptimal - fix this in a
|
||||
// future update.
|
||||
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
|
||||
received.Add(cs.Chunk.Size())
|
||||
} else {
|
||||
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||
t.update(l, 0, err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
@@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
return err
|
||||
}
|
||||
if received.Load() != expected {
|
||||
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
|
||||
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
|
||||
}
|
||||
|
||||
md := blob.DigestFromBytes(m.Data)
|
||||
@@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manifest) All() iter.Seq[*Layer] {
|
||||
return func(yield func(*Layer) bool) {
|
||||
if !yield(m.Config) {
|
||||
return
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if !yield(l) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manifest) Size() int64 {
|
||||
var size int64
|
||||
if m.Config != nil {
|
||||
size += m.Config.Size
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
size += l.Size
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
//
|
||||
// NOTE: It adds an empty config object to the manifest, which is required by
|
||||
@@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
||||
return
|
||||
}
|
||||
|
||||
// A chunksums response is a sequence of chunksums in a
|
||||
// simple, easy to parse line-oriented format.
|
||||
// The response is a sequence of chunksums.
|
||||
//
|
||||
// Example:
|
||||
// Chunksums are chunks of a larger blob that can be
|
||||
// downloaded and verified independently.
|
||||
//
|
||||
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
// The chunksums endpoint is a GET request that returns a
|
||||
// sequence of chunksums in the following format:
|
||||
//
|
||||
// << HTTP/1.1 200 OK
|
||||
// << Content-Location: <blobURL>
|
||||
// <<
|
||||
// << <digest> <start>-<end>
|
||||
// << ...
|
||||
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// The blobURL is the URL to download the chunks from.
|
||||
// < HTTP/1.1 200 OK
|
||||
// < Content-Location: <blobURL>
|
||||
// <
|
||||
// < <digest> <start>-<end>
|
||||
// < ...
|
||||
//
|
||||
// The <blobURL> is the URL to download the chunks from and
|
||||
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||
// is the range the chunk in the blob.
|
||||
//
|
||||
// Ranges may be used directly in Range headers like
|
||||
// "bytes=<start>-<end>".
|
||||
//
|
||||
// The chunksums returned are guaranteed to be contiguous and
|
||||
// include all bytes of the layer. If the stream is cut short,
|
||||
// clients should retry.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
|
||||
@@ -9,17 +9,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
@@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkNotExist(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidName(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
err := rc.Pull(t.Context(), "://")
|
||||
if !errors.Is(err, ErrNameInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||
@@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, resp := range cases {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, resp)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "x")
|
||||
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullNotCached(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
var c *blob.DiskCache
|
||||
var rc *Registry
|
||||
|
||||
d := blob.DigestFromBytes("some data")
|
||||
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
io.WriteString(w, "some data")
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
|
||||
})
|
||||
|
||||
// Confirm that the layer does not exist locally
|
||||
_, err := rc.ResolveLocal("model")
|
||||
checkNotExist(t, err)
|
||||
|
||||
_, err = c.Get(d)
|
||||
checkNotExist(t, err)
|
||||
|
||||
err = rc.Pull(t.Context(), "model")
|
||||
check(err)
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "model")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal("model")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
}
|
||||
|
||||
// Confirm successful download
|
||||
info, err := c.Get(d)
|
||||
check(err)
|
||||
if info.Digest != d {
|
||||
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
|
||||
}
|
||||
if info.Size != 9 {
|
||||
t.Errorf("info.Size = %v; want %v", info.Size, 9)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(c.GetFile(d))
|
||||
check(err)
|
||||
if string(data) != "some data" {
|
||||
t.Errorf("data = %q; want %q", data, "exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullCached(t *testing.T) {
|
||||
cached := blob.DigestFromBytes("exists")
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(499) // should not be called
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
|
||||
}
|
||||
})
|
||||
|
||||
var errs []error
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(d *Layer, n int64, err error) {
|
||||
t.Logf("update %v %d %v", d, n, err)
|
||||
reads = append(reads, n)
|
||||
errs = append(errs, err)
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := rc.Pull(ctx, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{0, 6}
|
||||
if !errors.Is(errors.Join(errs...), ErrCached) {
|
||||
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
||||
}
|
||||
if !slices.Equal(reads, want) {
|
||||
t.Errorf("pairs = %v; want %v", reads, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "notfound")
|
||||
checkErrCode(t, err, 404, "")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
checkErrCode(t, err, 500, "an_error")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
return
|
||||
}
|
||||
})
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
|
||||
// interfere with pulling layers that are not cached
|
||||
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
x := blob.DigestFromBytes("xxxxxx")
|
||||
e := blob.DigestFromBytes("exists")
|
||||
y := blob.DigestFromBytes("yyyyyy")
|
||||
|
||||
for i := range 10 {
|
||||
t.Logf("iteration %d", i)
|
||||
|
||||
digests := []blob.Digest{x, e, y}
|
||||
|
||||
rand.Shuffle(len(digests), func(i, j int) {
|
||||
digests[i], digests[j] = digests[j], digests[i]
|
||||
})
|
||||
|
||||
manifest := fmt.Sprintf(`{
|
||||
"layers": [
|
||||
{"digest":"%s","size":6},
|
||||
{"digest":"%s","size":6},
|
||||
{"digest":"%s","size":6}
|
||||
]
|
||||
}`, digests[0], digests[1], digests[2])
|
||||
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch path.Base(r.URL.Path) {
|
||||
case "latest":
|
||||
io.WriteString(w, manifest)
|
||||
case x.String():
|
||||
io.WriteString(w, "xxxxxx")
|
||||
case e.String():
|
||||
io.WriteString(w, "exists")
|
||||
case y.String():
|
||||
io.WriteString(w, "yyyyyy")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected request: %v", r))
|
||||
}
|
||||
})
|
||||
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Logf("update %v %d %v", l, n, err)
|
||||
},
|
||||
})
|
||||
|
||||
// Check that we pull all layers that we can.
|
||||
|
||||
err := rc.Pull(ctx, "mixed")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, d := range digests {
|
||||
info, err := c.Get(d)
|
||||
if err != nil {
|
||||
t.Fatalf("Get(%v): %v", d, err)
|
||||
}
|
||||
if info.Size != 6 {
|
||||
t.Errorf("info.Size = %v; want %v", info.Size, 6)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryResolveByDigest(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
@@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestCanRetry(t *testing.T) {
|
||||
cases := []struct {
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{nil, false},
|
||||
{errors.New("x"), false},
|
||||
{ErrCached, false},
|
||||
{ErrManifestInvalid, false},
|
||||
{ErrNameInvalid, false},
|
||||
{&Error{Status: 100}, false},
|
||||
{&Error{Status: 500}, true},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
if got := canRetry(tt.err); got != tt.want {
|
||||
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorUnmarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
@@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
check := testutil.Checker(t)
|
||||
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
// make a blob and link it
|
||||
d := blob.DigestFromBytes("{}")
|
||||
err := blob.PutBytes(rc.Cache, d, "{}")
|
||||
check(err)
|
||||
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
||||
check(err)
|
||||
|
||||
// confirm linked
|
||||
_, err := rc.ResolveLocal("single")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
_, err = rc.ResolveLocal("single")
|
||||
check(err)
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink("single")
|
||||
testutil.Check(t, err)
|
||||
check(err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal("single")
|
||||
@@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
ok, err := rc.Unlink("manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPullChunksums(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
// Many tests from here out, in this file are based on a single blob, "abc",
|
||||
// with the checksum of its sha256 hash. The checksum is:
|
||||
//
|
||||
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
||||
//
|
||||
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
||||
// to be the most readable and maintainable approach. The sum is consistently
|
||||
// used in the tests and unique so searches do not yield false positives.
|
||||
|
||||
content := "hello"
|
||||
var chunksums string
|
||||
contentDigest := func() blob.Digest {
|
||||
return blob.DigestFromBytes(content)
|
||||
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
||||
t.Helper()
|
||||
if got := req.URL.Path; got != path {
|
||||
t.Errorf("URL = %q, want %q", got, path)
|
||||
}
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/manifests/latest"):
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
|
||||
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
|
||||
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
|
||||
w.Header().Set("Content-Location", loc)
|
||||
io.WriteString(w, chunksums)
|
||||
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
|
||||
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
|
||||
default:
|
||||
t.Errorf("unexpected request: %v", r)
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
if req.Method != method {
|
||||
t.Errorf("Method = %q, want %q", req.Method, method)
|
||||
}
|
||||
}
|
||||
|
||||
rc.MaxStreams = 1 // prevent concurrent chunk downloads
|
||||
rc.ChunkingThreshold = 1 // for all blobs to be chunked
|
||||
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
|
||||
s := httptest.NewServer(h)
|
||||
t.Cleanup(s.Close)
|
||||
cache, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Logf("Update: %v %d %v", l, n, err)
|
||||
mu.Lock()
|
||||
reads = append(reads, n)
|
||||
mu.Unlock()
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
},
|
||||
})
|
||||
|
||||
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
|
||||
blob.DigestFromBytes("hel"),
|
||||
blob.DigestFromBytes("lo"),
|
||||
)
|
||||
err := rc.Pull(ctx, "test")
|
||||
check(err)
|
||||
wantReads := []int64{
|
||||
0, // initial signaling of layer pull starting
|
||||
3, // first chunk read
|
||||
2, // second chunk read
|
||||
}
|
||||
if !slices.Equal(reads, wantReads) {
|
||||
t.Errorf("reads = %v; want %v", reads, wantReads)
|
||||
rc := &Registry{
|
||||
Cache: cache,
|
||||
HTTPClient: &http.Client{Transport: &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, s.Listener.Addr().String())
|
||||
},
|
||||
}},
|
||||
}
|
||||
return rc, ctx
|
||||
}
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "test")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal("test")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
}
|
||||
for i := range mg.Layers {
|
||||
_, err = c.Get(mg.Layers[i].Digest)
|
||||
if err != nil {
|
||||
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
|
||||
func TestPullChunked(t *testing.T) {
|
||||
var steps atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch steps.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
t.Logf("writing c")
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// missing chunks
|
||||
content = "llama"
|
||||
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
|
||||
err = rc.Pull(ctx, "missingchunks")
|
||||
if err == nil {
|
||||
t.Error("expected error because of missing chunks")
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
testutil.Check(t, err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
testutil.Check(t, err)
|
||||
|
||||
if g := steps.Load(); g != 4 {
|
||||
t.Fatalf("got %d steps, want 4", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullCached(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
// Premeptively cache the blob
|
||||
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
check(err)
|
||||
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
||||
check(err)
|
||||
|
||||
// Pull only the manifest, which should be enough to resolve the cached blob
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPullManifestError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var got *Error
|
||||
if !errors.Is(err, ErrModelNotFound) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `!`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var want *json.SyntaxError
|
||||
if !errors.As(err, &want) {
|
||||
t.Fatalf("err = %T, want %T", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerChecksumError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
var written atomic.Int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
var got *Error
|
||||
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
||||
t.Fatalf("err = %v, want %v", err, got)
|
||||
}
|
||||
|
||||
if g := written.Load(); g != 1 {
|
||||
t.Fatalf("wrote %d bytes, want 1", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullChunksumStreamError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
|
||||
// Write one valid chunksum and one invalid chunksum
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
||||
fmt.Fprint(w, "sha256:!") // invalid
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
got := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(got, ErrIncomplete) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
||||
}
|
||||
}
|
||||
|
||||
type flushAfterWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = f.w.Write(p)
|
||||
f.w.(http.Flusher).Flush() // panic if not a flusher
|
||||
return
|
||||
}
|
||||
|
||||
func TestPullChunksumStreaming(t *testing.T) {
|
||||
csr, csw := io.Pipe()
|
||||
defer csw.Close()
|
||||
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
||||
_, err := io.Copy(fw, csr)
|
||||
if err != nil {
|
||||
t.Errorf("copy: %v", err)
|
||||
}
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
update := make(chan int64, 1)
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if n > 0 {
|
||||
update <- n
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
||||
}()
|
||||
|
||||
// Send first chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
if g := <-update; g != 2 {
|
||||
t.Fatalf("got %d, want 2", g)
|
||||
}
|
||||
|
||||
// now send the second chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
if g := <-update; g != 1 {
|
||||
t.Fatalf("got %d, want 1", g)
|
||||
}
|
||||
csw.Close()
|
||||
testutil.Check(t, <-errc)
|
||||
}
|
||||
|
||||
func TestPullChunksumsCached(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1 // force serial processing of chunksums
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
// Cancel the pull after the first chunksum is processed, but before
|
||||
// the second chunksum is processed (which is waiting because
|
||||
// MaxStreams=1). This should cause the second chunksum to error out
|
||||
// leaving the blob incomplete.
|
||||
ctx = WithTrace(ctx, &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if n > 0 {
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
})
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Reset state and pull again to ensure the blob chunks that should
|
||||
// have been cached are, and the remaining chunk was downloaded, making
|
||||
// the blob complete.
|
||||
step.Store(0)
|
||||
var written atomic.Int64
|
||||
var cached atomic.Int64
|
||||
ctx = WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if errors.Is(err, ErrCached) {
|
||||
cached.Add(n)
|
||||
}
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
check(err)
|
||||
|
||||
if g := written.Load(); g != 3 {
|
||||
t.Fatalf("wrote %d bytes, want 3", g)
|
||||
}
|
||||
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
||||
t.Fatalf("cached %d bytes, want 3", g)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,9 +31,10 @@ const (
|
||||
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
ErrModelPathInvalid = errors.New("invalid model path")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
@@ -73,8 +74,6 @@ func ParseModelPath(name string) ModelPath {
|
||||
return mp
|
||||
}
|
||||
|
||||
var errModelPathInvalid = errors.New("invalid model path")
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ var (
|
||||
errBadTemplate = errors.New("template error")
|
||||
)
|
||||
|
||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
|
||||
opts := api.DefaultOptions()
|
||||
if err := opts.FromMap(model.Options); err != nil {
|
||||
return api.Options{}, err
|
||||
@@ -87,7 +87,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
|
||||
|
||||
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
|
||||
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
|
||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||
if name == "" {
|
||||
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
|
||||
}
|
||||
@@ -144,7 +144,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(name.String())
|
||||
m, err := GetModel(name.String())
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, fs.ErrNotExist):
|
||||
@@ -159,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
// expire the runner
|
||||
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||
s.sched.expireRunner(model)
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
@@ -176,9 +176,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
caps := []model.Capability{model.CapabilityCompletion}
|
||||
if req.Suffix != "" {
|
||||
caps = append(caps, CapabilityInsert)
|
||||
caps = append(caps, model.CapabilityInsert)
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
isMllama := checkMllamaModelFamily(model)
|
||||
isMllama := checkMllamaModelFamily(m)
|
||||
if isMllama && len(req.Images) > 1 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
|
||||
return
|
||||
@@ -211,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
images := make([]llm.ImageData, len(req.Images))
|
||||
for i := range req.Images {
|
||||
if isMllama && len(model.ProjectorPaths) > 0 {
|
||||
if isMllama && len(m.ProjectorPaths) > 0 {
|
||||
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
|
||||
@@ -422,7 +422,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
@@ -530,7 +530,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
@@ -777,7 +777,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
return nil, errModelPathInvalid
|
||||
return nil, ErrModelPathInvalid
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
@@ -813,19 +813,20 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
|
||||
resp := &api.ShowResponse{
|
||||
License: strings.Join(m.License, "\n"),
|
||||
System: m.System,
|
||||
Template: m.Template.String(),
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
License: strings.Join(m.License, "\n"),
|
||||
System: m.System,
|
||||
Template: m.Template.String(),
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
}
|
||||
|
||||
var params []string
|
||||
cs := 30
|
||||
for k, v := range m.Options {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
case []any:
|
||||
for _, nv := range val {
|
||||
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
|
||||
}
|
||||
@@ -1335,7 +1336,7 @@ func Serve(ln net.Listener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForStream(c *gin.Context, ch chan interface{}) {
|
||||
func waitForStream(c *gin.Context, ch chan any) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
for resp := range ch {
|
||||
switch r := resp.(type) {
|
||||
@@ -1468,9 +1469,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
caps := []model.Capability{model.CapabilityCompletion}
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, CapabilityTools)
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -37,7 +38,7 @@ type Scheduler struct {
|
||||
pendingReqCh chan *LlmRequest
|
||||
finishedReqCh chan *LlmRequest
|
||||
expiredCh chan *runnerRef
|
||||
unloadedCh chan interface{}
|
||||
unloadedCh chan any
|
||||
|
||||
loaded map[string]*runnerRef
|
||||
loadedMu sync.Mutex
|
||||
@@ -67,7 +68,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
||||
pendingReqCh: make(chan *LlmRequest, maxQueue),
|
||||
finishedReqCh: make(chan *LlmRequest, maxQueue),
|
||||
expiredCh: make(chan *runnerRef, maxQueue),
|
||||
unloadedCh: make(chan interface{}, maxQueue),
|
||||
unloadedCh: make(chan any, maxQueue),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: llm.NewLlamaServer,
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
@@ -195,7 +196,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Embedding models should always be loaded with parallel=1
|
||||
if pending.model.CheckCapabilities(CapabilityCompletion) != nil {
|
||||
if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil {
|
||||
numParallel = 1
|
||||
}
|
||||
|
||||
@@ -617,8 +618,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
// a before and after GPU memory allocation. The returned channel
|
||||
// will be notified when we're done waiting, or have timed out and should
|
||||
// proceed anyway
|
||||
func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
|
||||
finished := make(chan interface{}, 1)
|
||||
func (runner *runnerRef) waitForVRAMRecovery() chan any {
|
||||
finished := make(chan any, 1)
|
||||
|
||||
// CPU or Metal don't need checking, so no waiting required
|
||||
// windows can page VRAM, only cuda currently can report accurate used vram usage
|
||||
@@ -711,7 +712,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
||||
req.opts.NumCtx = req.origNumCtx * p
|
||||
if !envconfig.SchedSpread() {
|
||||
for _, g := range sgl {
|
||||
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
||||
*numParallel = p
|
||||
return []discover.GpuInfo{g}
|
||||
@@ -727,7 +728,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
||||
// Now try all the GPUs
|
||||
for _, p := range numParallelToTry {
|
||||
req.opts.NumCtx = req.origNumCtx * p
|
||||
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
|
||||
*numParallel = p
|
||||
return sgl
|
||||
@@ -750,7 +751,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
|
||||
var bestEstimate uint64
|
||||
var bestFit int
|
||||
for i, gl := range byLibrary {
|
||||
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
|
||||
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
|
||||
if estimatedVRAM > bestEstimate {
|
||||
bestEstimate = estimatedVRAM
|
||||
bestFit = i
|
||||
@@ -825,7 +826,7 @@ func (s *Scheduler) expireRunner(model *Model) {
|
||||
// If not, pick a runner to unload, else return nil and the request can be loaded
|
||||
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
|
||||
slog.Debug("evaluating if CPU model load will fit in available system memory")
|
||||
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
|
||||
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
|
||||
if estimate.TotalSize <= gpus[0].FreeMemory {
|
||||
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
|
||||
return nil
|
||||
|
||||
15
types/model/capability.go
Normal file
15
types/model/capability.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package model
|
||||
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
return string(c)
|
||||
}
|
||||
Reference in New Issue
Block a user