Compare commits
8 Commits
v0.5.2-rc2...v0.5.2
| Author | SHA1 | Date |
|---|---|---|
|  | 60f75560a2 |  |
|  | e28f2d4900 |  |
|  | c216850523 |  |
|  | 18f6a98bd6 |  |
|  | b1fd7fef86 |  |
|  | 36d111e788 |  |
|  | 9039c821a2 |  |
|  | 581a4a5553 |  |
4 .github/workflows/release.yaml vendored
@@ -447,15 +447,19 @@ jobs:
      - uses: actions/download-artifact@v4
        with:
          name: generate-windows-cpu
          path: dist/windows-amd64/
      - uses: actions/download-artifact@v4
        with:
          name: generate-windows-cuda-11.3
          path: dist/windows-amd64/
      - uses: actions/download-artifact@v4
        with:
          name: generate-windows-cuda-12.4
          path: dist/windows-amd64/
      - uses: actions/download-artifact@v4
        with:
          name: generate-windows-rocm
          path: dist/windows-amd64/
      - uses: actions/download-artifact@v4
        with:
          name: windows-arm64
@@ -66,6 +66,7 @@ COPY . .
ARG OLLAMA_SKIP_CUDA_GENERATE
ARG OLLAMA_SKIP_ROCM_GENERATE
ARG OLLAMA_FAST_BUILD
ARG VERSION
RUN --mount=type=cache,target=/root/.ccache \
    if grep "^flags" /proc/cpuinfo|grep avx>/dev/null; then \
        make -j $(expr $(nproc) / 2 ) dist ; \
@@ -91,6 +92,7 @@ WORKDIR /go/src/github.com/ollama/ollama/
COPY . .
ARG CGO_CFLAGS
ENV GOARCH arm64
ARG VERSION
RUN --mount=type=cache,target=/root/.ccache \
    make -j 5 dist_cuda_v11 \
        CUDA_ARCHITECTURES="72;87" \
@@ -109,6 +111,7 @@ WORKDIR /go/src/github.com/ollama/ollama/
COPY . .
ARG CGO_CFLAGS
ENV GOARCH arm64
ARG VERSION
RUN --mount=type=cache,target=/root/.ccache \
    make -j 5 dist_cuda_v12 \
        CUDA_ARCHITECTURES="87" \
@@ -120,6 +123,7 @@ FROM --platform=linux/arm64 unified-builder-arm64 AS build-arm64
COPY . .
ARG OLLAMA_SKIP_CUDA_GENERATE
ARG OLLAMA_FAST_BUILD
ARG VERSION
RUN --mount=type=cache,target=/root/.ccache \
    make -j 5 dist
COPY --from=runners-jetpack5-arm64 /go/src/github.com/ollama/ollama/dist/ dist/
@@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
    var data [][]string

    for _, m := range models.Models {
        if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
        if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
            data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
        }
    }
@@ -233,6 +233,8 @@ curl http://localhost:11434/v1/embeddings \
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `stream_options`
    - [x] `include_usage`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
@@ -261,6 +263,8 @@ curl http://localhost:11434/v1/embeddings \
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `stream_options`
    - [x] `include_usage`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
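A hedged illustration of the `stream_options` support documented above (not part of the diff): a minimal Go client asking the OpenAI-compatible endpoint for a final usage chunk. The model name "llama3.2" is a placeholder for any locally available model.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Build a streaming chat request; stream_options.include_usage asks the
	// server to emit a final chunk carrying prompt/completion token counts.
	body, err := json.Marshal(map[string]any{
		"model":          "llama3.2", // placeholder model name
		"stream":         true,
		"stream_options": map[string]any{"include_usage": true},
		"messages": []map[string]string{
			{"role": "user", "content": "Hello"},
		},
	})
	if err != nil {
		panic(err)
	}
	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}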
76 llama/grammar_test.go Normal file
@@ -0,0 +1,76 @@
package llama

import (
    "bufio"
    "bytes"
    "strings"
    "testing"
)

// https://github.com/ollama/ollama/issues/7978
const issue7978JSONSchema = `{
  "type": "object",
  "properties": {
    "steps": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "explanation": { "type": "string" },
          "output": { "type": "string" }
        },
        "required": ["explanation", "output"],
        "additionalProperties": false
      }
    },
    "final_answer": { "type": "string" }
  },
  "required": ["steps", "final_answer"],
  "additionalProperties": false
}`

func TestIssue7978(t *testing.T) {
    g := SchemaToGrammar([]byte(issue7978JSONSchema))
    if g == nil {
        t.Fatal("failed to convert JSON schema to grammar")
    }

    t.Logf("grammar:\n%s", g)
    t.Log()

    var sawSteps bool
    s := bufio.NewScanner(bytes.NewReader(g))
    for s.Scan() {
        line := s.Text()
        if strings.Contains(line, "steps") {
            sawSteps = true
        }
        if strings.Contains(line, "final-answer") && !sawSteps {
            t.Error("expected 'steps' before 'final-answer'")
        }
    }
}

func TestSchemaToGrammer(t *testing.T) {
    cases := []struct {
        schema string
        prefix []byte // nil is check as nil
    }{
        {`invalid`, nil},

        // Simple heuristic/smoke test
        {`{"type":"object"}`, []byte("root ::= object")},
    }

    for _, c := range cases {
        t.Run("x", func(t *testing.T) {
            g := SchemaToGrammar([]byte(c.schema))
            if c.prefix == nil && g != nil {
                t.Fatalf("grammar = %v, want nil", g)
            }
            if !bytes.HasPrefix(g, c.prefix) {
                t.Errorf("grammar = %q, want %q", g, c.prefix)
            }
        })
    }
}
2 llama/json-schema-to-grammar.cpp vendored
@@ -417,7 +417,7 @@ class SchemaConverter {
private:
    std::function<json(const std::string &)> _fetch_json;
    bool _dotall;
    std::map<std::string, std::string> _rules;
    std::unordered_map<std::string, std::string> _rules;
    std::unordered_map<std::string, json> _refs;
    std::unordered_set<std::string> _refs_being_resolved;
    std::vector<std::string> _errors;
@@ -86,12 +86,10 @@ COMPILER inline get_compiler() {
import "C"

import (
    "bytes"
    _ "embed"
    "encoding/json"
    "errors"
    "fmt"
    "log/slog"
    "os"
    "runtime"
    "runtime/cgo"
    "slices"
@@ -134,7 +132,7 @@ func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
        return
    }

    fmt.Print(C.GoString(text))
    fmt.Fprint(os.Stderr, C.GoString(text))
}

func GetModelArch(modelPath string) (string, error) {
@@ -721,21 +719,10 @@ func (s *SamplingContext) Accept(id int, applyGrammar bool) {
    C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}

type JsonSchema struct {
    Defs map[string]any `json:"$defs,omitempty"`
    Properties map[string]any `json:"properties,omitempty"`
    Required []string `json:"required,omitempty"`
    Title string `json:"title,omitempty"`
    Type string `json:"type,omitempty"`
}

func (js JsonSchema) AsGrammar() string {
    var b bytes.Buffer
    if err := json.NewEncoder(&b).Encode(js); err != nil {
        return ""
    }

    cStr := C.CString(b.String())
// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
    cStr := C.CString(string(schema))
    defer C.free(unsafe.Pointer(cStr))

    // Allocate buffer for grammar output with reasonable size
@@ -743,10 +730,10 @@ func (js JsonSchema) AsGrammar() string {
    buf := make([]byte, maxLen)

    // Call C function to convert schema to grammar
    length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
    if length == 0 {
        slog.Warn("unable to convert schema to grammar")
    n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
    if n == 0 {
        // preserve nil
        return nil
    }

    return string(buf[:length])
    return buf[:n]
}
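A hedged usage sketch for the new SchemaToGrammar API shown above (illustrative, not part of the diff). The schema literal is made up, and building the llama package requires cgo.

package main

import (
	"fmt"

	"github.com/ollama/ollama/llama"
)

func main() {
	// An arbitrary example schema; SchemaToGrammar returns nil when the input
	// is not valid JSON or not a valid JSON schema.
	schema := []byte(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}`)
	if g := llama.SchemaToGrammar(schema); g != nil {
		fmt.Printf("grammar:\n%s\n", g)
	} else {
		fmt.Println("invalid schema")
	}
}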
@@ -1,70 +1 @@
package llama

import (
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
)

func TestJsonSchema(t *testing.T) {
    testCases := []struct {
        name string
        schema JsonSchema
        expected string
    }{
        {
            name: "empty schema",
            schema: JsonSchema{
                Type: "object",
            },
            expected: `array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null`,
        },
        {
            name: "invalid schema with circular reference",
            schema: JsonSchema{
                Type: "object",
                Properties: map[string]any{
                    "self": map[string]any{
                        "$ref": "#", // Self reference
                    },
                },
            },
            expected: "", // Should return empty string for invalid schema
        },
        {
            name: "schema with invalid type",
            schema: JsonSchema{
                Type: "invalid_type", // Invalid type
                Properties: map[string]any{
                    "foo": map[string]any{
                        "type": "string",
                    },
                },
            },
            expected: "", // Should return empty string for invalid schema
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            result := tc.schema.AsGrammar()
            if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
                if diff := cmp.Diff(tc.expected, result); diff != "" {
                    t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
                }
            }
        })
    }
}
@@ -0,0 +1,22 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: ParthSareen <parth.sareen@ollama.com>
Date: Wed, 11 Dec 2024 15:37:32 -0800
Subject: [PATCH] Maintain ordering for rules for grammar

---
 common/json-schema-to-grammar.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index dadc18c8..2a8dbd22 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -391,7 +391,7 @@ class SchemaConverter {
 private:
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
-    std::map<std::string, std::string> _rules;
+    std::unordered_map<std::string, std::string> _rules;
     std::unordered_map<std::string, json> _refs;
     std::unordered_set<std::string> _refs_being_resolved;
     std::vector<std::string> _errors;
2 llama/sampling_ext.cpp vendored
@@ -49,7 +49,7 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
{
    try
    {
        nlohmann::json schema = nlohmann::json::parse(json_schema);
        nlohmann::ordered_json schema = nlohmann::ordered_json::parse(json_schema);
        std::string grammar_str = json_schema_to_grammar(schema);
        size_t len = grammar_str.length();
        if (len >= max_len)
@@ -610,7 +610,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
    }
}

const jsonGrammar = `
var grammarJSON = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
@@ -722,22 +722,19 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
        return fmt.Errorf("unexpected server status: %s", status.ToString())
    }

    // TODO (parthsareen): Move conversion to grammar with sampling logic
    // API should do error handling for invalid formats
    if req.Format != nil && strings.TrimSpace(string(req.Format)) != "null" {
        if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
            request["grammar"] = jsonGrammar
            if !strings.Contains(strings.ToLower(req.Prompt), "json") {
                slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
    if len(req.Format) > 0 {
        switch {
        case bytes.Equal(req.Format, []byte(`"json"`)):
            request["grammar"] = grammarJSON
        case bytes.HasPrefix(req.Format, []byte("{")):
            // User provided a JSON schema
            g := llama.SchemaToGrammar(req.Format)
            if g == nil {
                return fmt.Errorf("invalid JSON schema in format")
            }
        } else if schema, err := func() (llama.JsonSchema, error) {
            var schema llama.JsonSchema
            err := json.Unmarshal(req.Format, &schema)
            return schema, err
        }(); err == nil {
            request["grammar"] = schema.AsGrammar()
        } else {
            slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
            request["grammar"] = string(g)
        default:
            return errors.New(`invalid format: expected "json" or a JSON schema`)
        }
    }
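The Completion path above now accepts `format` either as the literal string "json" or as a JSON schema object. A hedged client-side illustration against the native /api/generate endpoint (placeholder model name, made-up schema; not part of the diff):

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// "format" may be the string "json" or, as of this change, an inline JSON
	// schema object; "llama3.2" is a placeholder model name.
	payload := []byte(`{
	  "model": "llama3.2",
	  "prompt": "Name one fruit.",
	  "stream": false,
	  "format": {
	    "type": "object",
	    "properties": {"fruit": {"type": "string"}},
	    "required": ["fruit"]
	  }
	}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}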
125 openai/openai.go
@@ -67,7 +67,7 @@ type ResponseFormat struct {
}

type JsonSchema struct {
    Schema map[string]any `json:"schema"`
    Schema json.RawMessage `json:"schema"`
}

type EmbedRequest struct {
@@ -75,10 +75,15 @@ type EmbedRequest struct {
    Model string `json:"model"`
}

type StreamOptions struct {
    IncludeUsage bool `json:"include_usage"`
}

type ChatCompletionRequest struct {
    Model string `json:"model"`
    Messages []Message `json:"messages"`
    Stream bool `json:"stream"`
    StreamOptions *StreamOptions `json:"stream_options"`
    MaxTokens *int `json:"max_tokens"`
    Seed *int `json:"seed"`
    Stop any `json:"stop"`
@@ -107,21 +112,23 @@ type ChatCompletionChunk struct {
    Model string `json:"model"`
    SystemFingerprint string `json:"system_fingerprint"`
    Choices []ChunkChoice `json:"choices"`
    Usage *Usage `json:"usage,omitempty"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
    Model string `json:"model"`
    Prompt string `json:"prompt"`
    FrequencyPenalty float32 `json:"frequency_penalty"`
    MaxTokens *int `json:"max_tokens"`
    PresencePenalty float32 `json:"presence_penalty"`
    Seed *int `json:"seed"`
    Stop any `json:"stop"`
    Stream bool `json:"stream"`
    Temperature *float32 `json:"temperature"`
    TopP float32 `json:"top_p"`
    Suffix string `json:"suffix"`
    Model string `json:"model"`
    Prompt string `json:"prompt"`
    FrequencyPenalty float32 `json:"frequency_penalty"`
    MaxTokens *int `json:"max_tokens"`
    PresencePenalty float32 `json:"presence_penalty"`
    Seed *int `json:"seed"`
    Stop any `json:"stop"`
    Stream bool `json:"stream"`
    StreamOptions *StreamOptions `json:"stream_options"`
    Temperature *float32 `json:"temperature"`
    TopP float32 `json:"top_p"`
    Suffix string `json:"suffix"`
}

type Completion struct {
@@ -141,6 +148,7 @@ type CompletionChunk struct {
    Choices []CompleteChunkChoice `json:"choices"`
    Model string `json:"model"`
    SystemFingerprint string `json:"system_fingerprint"`
    Usage *Usage `json:"usage,omitempty"`
}

type ToolCall struct {
@@ -197,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
    return ErrorResponse{Error{Type: etype, Message: message}}
}

func toUsage(r api.ChatResponse) Usage {
    return Usage{
        PromptTokens: r.PromptEvalCount,
        CompletionTokens: r.EvalCount,
        TotalTokens: r.PromptEvalCount + r.EvalCount,
    }
}

func toolCallId() string {
    const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
    b := make([]byte, 8)
@@ -246,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
            return nil
        }(r.DoneReason),
        }},
        Usage: Usage{
            PromptTokens: r.PromptEvalCount,
            CompletionTokens: r.EvalCount,
            TotalTokens: r.PromptEvalCount + r.EvalCount,
        },
        Usage: toUsage(r),
    }
}

@@ -275,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
    }
}

func toUsageGenerate(r api.GenerateResponse) Usage {
    return Usage{
        PromptTokens: r.PromptEvalCount,
        CompletionTokens: r.EvalCount,
        TotalTokens: r.PromptEvalCount + r.EvalCount,
    }
}

func toCompletion(id string, r api.GenerateResponse) Completion {
    return Completion{
        Id: id,
@@ -292,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
            return nil
        }(r.DoneReason),
        }},
        Usage: Usage{
            PromptTokens: r.PromptEvalCount,
            CompletionTokens: r.EvalCount,
            TotalTokens: r.PromptEvalCount + r.EvalCount,
        },
        Usage: toUsageGenerate(r),
    }
}

@@ -495,11 +511,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
        format = json.RawMessage(`"json"`)
    case "json_schema":
        if r.ResponseFormat.JsonSchema != nil {
            schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
            if err != nil {
                return nil, fmt.Errorf("failed to marshal json schema: %w", err)
            }
            format = schema
            format = r.ResponseFormat.JsonSchema.Schema
        }
    }
}
@@ -570,14 +582,16 @@ type BaseWriter struct {
}

type ChatWriter struct {
    stream bool
    id string
    stream bool
    streamOptions *StreamOptions
    id string
    BaseWriter
}

type CompleteWriter struct {
    stream bool
    id string
    stream bool
    streamOptions *StreamOptions
    id string
    BaseWriter
}

@@ -620,7 +634,8 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {

    // chat chunk
    if w.stream {
        d, err := json.Marshal(toChunk(w.id, chatResponse))
        c := toChunk(w.id, chatResponse)
        d, err := json.Marshal(c)
        if err != nil {
            return 0, err
        }
@@ -632,6 +647,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
        }

        if chatResponse.Done {
            if w.streamOptions != nil && w.streamOptions.IncludeUsage {
                u := toUsage(chatResponse)
                c.Usage = &u
                c.Choices = []ChunkChoice{}
                d, err := json.Marshal(c)
                if err != nil {
                    return 0, err
                }
                _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
                if err != nil {
                    return 0, err
                }
            }
            _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
            if err != nil {
                return 0, err
@@ -669,7 +697,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {

    // completion chunk
    if w.stream {
        d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
        c := toCompleteChunk(w.id, generateResponse)
        if w.streamOptions != nil && w.streamOptions.IncludeUsage {
            c.Usage = &Usage{}
        }
        d, err := json.Marshal(c)
        if err != nil {
            return 0, err
        }
@@ -681,6 +713,19 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
        }

        if generateResponse.Done {
            if w.streamOptions != nil && w.streamOptions.IncludeUsage {
                u := toUsageGenerate(generateResponse)
                c.Usage = &u
                c.Choices = []CompleteChunkChoice{}
                d, err := json.Marshal(c)
                if err != nil {
                    return 0, err
                }
                _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
                if err != nil {
                    return 0, err
                }
            }
            _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
            if err != nil {
                return 0, err
@@ -843,9 +888,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
        c.Request.Body = io.NopCloser(&b)

        w := &CompleteWriter{
            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
            stream: req.Stream,
            id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
            stream: req.Stream,
            id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
            streamOptions: req.StreamOptions,
        }

        c.Writer = w
@@ -925,9 +971,10 @@ func ChatMiddleware() gin.HandlerFunc {
        c.Request.Body = io.NopCloser(&b)

        w := &ChatWriter{
            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
            stream: req.Stream,
            id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
            stream: req.Stream,
            id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
            streamOptions: req.StreamOptions,
        }

        c.Writer = w
@@ -112,6 +112,45 @@ func TestChatMiddleware(t *testing.T) {
            Stream: &True,
        },
    },
    {
        name: "chat handler with streaming usage",
        body: `{
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Hello"}
            ],
            "stream": true,
            "stream_options": {"include_usage": true},
            "max_tokens": 999,
            "seed": 123,
            "stop": ["\n", "stop"],
            "temperature": 3.0,
            "frequency_penalty": 4.0,
            "presence_penalty": 5.0,
            "top_p": 6.0,
            "response_format": {"type": "json_object"}
        }`,
        req: api.ChatRequest{
            Model: "test-model",
            Messages: []api.Message{
                {
                    Role: "user",
                    Content: "Hello",
                },
            },
            Options: map[string]any{
                "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
                "seed": 123.0,
                "stop": []any{"\n", "stop"},
                "temperature": 3.0,
                "frequency_penalty": 4.0,
                "presence_penalty": 5.0,
                "top_p": 6.0,
            },
            Format: json.RawMessage(`"json"`),
            Stream: &True,
        },
    },
    {
        name: "chat handler with image content",
        body: `{
@@ -363,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
            Stream: &False,
        },
    },
    {
        name: "completions handler stream",
        body: `{
            "model": "test-model",
            "prompt": "Hello",
            "stream": true,
            "temperature": 0.8,
            "stop": ["\n", "stop"],
            "suffix": "suffix"
        }`,
        req: api.GenerateRequest{
            Model: "test-model",
            Prompt: "Hello",
            Options: map[string]any{
                "frequency_penalty": 0.0,
                "presence_penalty": 0.0,
                "temperature": 0.8,
                "top_p": 1.0,
                "stop": []any{"\n", "stop"},
            },
            Suffix: "suffix",
            Stream: &True,
        },
    },
    {
        name: "completions handler stream with usage",
        body: `{
            "model": "test-model",
            "prompt": "Hello",
            "stream": true,
            "stream_options": {"include_usage": true},
            "temperature": 0.8,
            "stop": ["\n", "stop"],
            "suffix": "suffix"
        }`,
        req: api.GenerateRequest{
            Model: "test-model",
            Prompt: "Hello",
            Options: map[string]any{
                "frequency_penalty": 0.0,
                "presence_penalty": 0.0,
                "temperature": 0.8,
                "top_p": 1.0,
                "stop": []any{"\n", "stop"},
            },
            Suffix: "suffix",
            Stream: &True,
        },
    },
    {
        name: "completions handler error forwarding",
        body: `{
@@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
    switch command {
    case "model", "adapter":
        if name := model.ParseName(c.Args); name.IsValid() && command == "model" {
            name, err := getExistingName(name)
            if err != nil {
                return err
            }
            baseLayers, err = parseFromModel(ctx, name, fn)
            if err != nil {
                return err
@@ -3,6 +3,7 @@ package server
import (
    "errors"
    "fmt"
    "io/fs"
    "net/url"
    "os"
    "path/filepath"
@@ -10,6 +11,7 @@ import (
    "strings"

    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/types/model"
)

type ModelPath struct {
@@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string {

// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
    if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
        return filepath.Join(envconfig.Models(), "manifests", p), nil
    name := model.Name{
        Host: mp.Registry,
        Namespace: mp.Namespace,
        Model: mp.Repository,
        Tag: mp.Tag,
    }

    return "", errModelPathInvalid
    if !name.IsValid() {
        return "", fs.ErrNotExist
    }
    return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}

func (mp ModelPath) BaseURL() *url.URL {
@@ -1,7 +1,6 @@
package server

import (
    "errors"
    "os"
    "path/filepath"
    "testing"
@@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) {
        })
    }
}

func TestInsecureModelpath(t *testing.T) {
    mp := ParseModelPath("../../..:something")
    if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
        t.Errorf("expected error: %v", err)
    }
}
112 server/routes.go
@@ -9,6 +9,7 @@ import (
    "errors"
    "fmt"
    "io"
    "io/fs"
    "log/slog"
    "math"
    "net"
@@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        return
    }

    model, err := GetModel(req.Model)
    name := model.ParseName(req.Model)
    if !name.IsValid() {
        // Ideally this is "invalid model name" but we're keeping with
        // what the API currently returns until we can change it.
        c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
        return
    }

    // We cannot currently consolidate this into GetModel because all we'll
    // induce infinite recursion given the current code structure.
    name, err := getExistingName(name)
    if err != nil {
        c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
        return
    }

    model, err := GetModel(name.String())
    if err != nil {
        switch {
        case os.IsNotExist(err):
        case errors.Is(err, fs.ErrNotExist):
            c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
        case err.Error() == "invalid model name":
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        caps = append(caps, CapabilityInsert)
    }

    r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
    r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
    if errors.Is(err, errCapabilityCompletion) {
        c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
        return
@@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
        }
    }

    r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
    name, err := getExistingName(model.ParseName(req.Model))
    if err != nil {
        c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
        return
    }

    r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
    if err != nil {
        handleScheduleError(c, req.Model, err)
        return
@@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
        return
    }

    r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
    name := model.ParseName(req.Model)
    if !name.IsValid() {
        c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
        return
    }

    r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
    if err != nil {
        handleScheduleError(c, req.Model, err)
        return
@@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
        return
    }

    var model string
    var mname string
    if req.Model != "" {
        model = req.Model
        mname = req.Model
    } else if req.Name != "" {
        model = req.Name
        mname = req.Name
    } else {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
        return
@@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
        ctx, cancel := context.WithCancel(c.Request.Context())
        defer cancel()

        if err := PushModel(ctx, model, regOpts, fn); err != nil {
        name, err := getExistingName(model.ParseName(mname))
        if err != nil {
            ch <- gin.H{"error": err.Error()}
            return
        }

        if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
            ch <- gin.H{"error": err.Error()}
        }
    }()
@@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
    streamResponse(c, ch)
}

// getExistingName returns the original, on disk name if the input name is a
// case-insensitive match, otherwise it returns the input name.
// getExistingName searches the models directory for the longest prefix match of
// the input name and returns the input name with all existing parts replaced
// with each part found. If no parts are found, the input name is returned as
// is.
func getExistingName(n model.Name) (model.Name, error) {
    var zero model.Name
    existing, err := Manifests(true)
    if err != nil {
        return zero, err
    }
    var set model.Name // tracks parts already canonicalized
    for e := range existing {
        if n.EqualFold(e) {
            return e, nil
        if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
            n.Host = e.Host
        }
        if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
            n.Namespace = e.Namespace
        }
        if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
            n.Model = e.Model
        }
        if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
            n.Tag = e.Tag
        }
    }
    return n, nil
@@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
    }

    if r.Path == "" && r.Modelfile == "" {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"})
        return
    }

@@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
        return
    }

    n, err := getExistingName(n)
    if err != nil {
        c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
        return
    }

    m, err := ParseNamedManifest(n)
    if err != nil {
        switch {
@@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
}

func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
    m, err := GetModel(req.Model)
    name := model.ParseName(req.Model)
    if !name.IsValid() {
        return nil, errModelPathInvalid
    }
    name, err := getExistingName(name)
    if err != nil {
        return nil, err
    }

    m, err := GetModel(name.String())
    if err != nil {
        return nil, err
    }
@@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
        msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
    }

    n := model.ParseName(req.Model)
    if !n.IsValid() {
        return nil, errors.New("invalid model name")
    }

    manifest, err := ParseNamedManifest(n)
    manifest, err := ParseNamedManifest(name)
    if err != nil {
        return nil, err
    }
@@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
        caps = append(caps, CapabilityTools)
    }

    r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
    name := model.ParseName(req.Model)
    if !name.IsValid() {
        c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
        return
    }
    name, err := getExistingName(name)
    if err != nil {
        c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
        return
    }

    r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
    if errors.Is(err, errCapabilityCompletion) {
        c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
        return
@@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) {
        t.Errorf("expected status 400, got %d", w.Code)
    }

    if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
    if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
        t.Errorf("mismatch (-got +want):\n%s", diff)
    }
})

@@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) {

    wantStableName := name()

    t.Logf("stable name: %s", wantStableName)

    // checkManifestList tests that there is strictly one manifest in the
    // models directory, and that the manifest is for the model under test.
    checkManifestList := func() {
@@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) {
        Destination: name(),
    }))
    checkManifestList()

    t.Logf("pushing")
    rr := createRequest(t, s.PushHandler, api.PushRequest{
        Model: name(),
        Insecure: true,
        Username: "alice",
        Password: "x",
    })
    checkOK(rr)
    if !strings.Contains(rr.Body.String(), `"status":"success"`) {
        t.Errorf("got = %q, want success", rr.Body.String())
    }
}

func TestShow(t *testing.T) {
@@ -223,12 +223,12 @@ func (n Name) String() string {
func (n Name) DisplayShortest() string {
    var sb strings.Builder

    if n.Host != defaultHost {
    if !strings.EqualFold(n.Host, defaultHost) {
        sb.WriteString(n.Host)
        sb.WriteByte('/')
        sb.WriteString(n.Namespace)
        sb.WriteByte('/')
    } else if n.Namespace != defaultNamespace {
    } else if !strings.EqualFold(n.Namespace, defaultNamespace) {
        sb.WriteString(n.Namespace)
        sb.WriteByte('/')
    }