Compare commits

..

1 Commits

Author SHA1 Message Date
ParthSareen
11eecdde86 cmd: enable use of structured outputs 2024-12-12 15:54:08 -08:00
6 changed files with 52 additions and 189 deletions

View File

@@ -1038,14 +1038,15 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
return nil
}
if opts.Format == "json" {
opts.Format = `"` + opts.Format + `"`
var format json.RawMessage
if opts.Format != "" {
format = json.RawMessage(opts.Format)
}
req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Format: format,
Options: opts.Options,
}
@@ -1127,8 +1128,9 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
}
if opts.Format == "json" {
opts.Format = `"` + opts.Format + `"`
var format json.RawMessage
if opts.Format != "" {
format = json.RawMessage(opts.Format)
}
request := api.GenerateRequest{
@@ -1136,7 +1138,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Prompt: opts.Prompt,
Context: generateContext,
Images: opts.Images,
Format: json.RawMessage(opts.Format),
Format: format,
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
@@ -1353,7 +1355,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
runCmd.Flags().String("format", "", `Response format ("json" or a JSON Schema)`)
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -261,11 +261,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
fmt.Println("Set 'quiet' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
if len(args) < 3 {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json or provide a JSON schema'")
} else if len(args) == 3 && (args[2] == "json" || args[2] == `"json"`) {
opts.Format = `"json"`
fmt.Println("Set format to 'json' mode.")
} else if len(args) > 3 && strings.HasPrefix(args[2], "{") {
opts.Format = strings.Join(args[2:], " ")
fmt.Printf("Set format to schema: \n'%s'.\n", opts.Format)
} else {
opts.Format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json or provide a JSON schema'")
}
case "noformat":
opts.Format = ""

View File

@@ -233,8 +233,6 @@ curl http://localhost:11434/v1/embeddings \
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `stream_options`
- [x] `include_usage`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
@@ -263,8 +261,6 @@ curl http://localhost:11434/v1/embeddings \
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `stream_options`
- [x] `include_usage`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`

View File

@@ -89,7 +89,6 @@ import (
_ "embed"
"errors"
"fmt"
"os"
"runtime"
"runtime/cgo"
"slices"
@@ -132,7 +131,7 @@ func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
return
}
fmt.Fprint(os.Stderr, C.GoString(text))
fmt.Print(C.GoString(text))
}
func GetModelArch(modelPath string) (string, error) {

View File

@@ -75,15 +75,10 @@ type EmbedRequest struct {
Model string `json:"model"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
MaxTokens *int `json:"max_tokens"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
@@ -112,23 +107,21 @@ type ChatCompletionChunk struct {
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []ChunkChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
}
type Completion struct {
@@ -148,7 +141,6 @@ type CompletionChunk struct {
Choices []CompleteChunkChoice `json:"choices"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Usage *Usage `json:"usage,omitempty"`
}
type ToolCall struct {
@@ -205,14 +197,6 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toUsage(r api.ChatResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
@@ -262,7 +246,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
return nil
}(r.DoneReason),
}},
Usage: toUsage(r),
Usage: Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
}
}
@@ -287,14 +275,6 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
}
}
func toUsageGenerate(r api.GenerateResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
func toCompletion(id string, r api.GenerateResponse) Completion {
return Completion{
Id: id,
@@ -312,7 +292,11 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
return nil
}(r.DoneReason),
}},
Usage: toUsageGenerate(r),
Usage: Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
}
}
@@ -582,16 +566,14 @@ type BaseWriter struct {
}
type ChatWriter struct {
stream bool
streamOptions *StreamOptions
id string
stream bool
id string
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *StreamOptions
id string
stream bool
id string
BaseWriter
}
@@ -634,8 +616,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
// chat chunk
if w.stream {
c := toChunk(w.id, chatResponse)
d, err := json.Marshal(c)
d, err := json.Marshal(toChunk(w.id, chatResponse))
if err != nil {
return 0, err
}
@@ -647,19 +628,6 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsage(chatResponse)
c.Usage = &u
c.Choices = []ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
@@ -697,11 +665,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
// completion chunk
if w.stream {
c := toCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &Usage{}
}
d, err := json.Marshal(c)
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
if err != nil {
return 0, err
}
@@ -713,19 +677,6 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
@@ -888,10 +839,9 @@ func CompletionsMiddleware() gin.HandlerFunc {
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
}
c.Writer = w
@@ -971,10 +921,9 @@ func ChatMiddleware() gin.HandlerFunc {
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
}
c.Writer = w

View File

@@ -112,45 +112,6 @@ func TestChatMiddleware(t *testing.T) {
Stream: &True,
},
},
{
name: "chat handler with streaming usage",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"stream_options": {"include_usage": true},
"max_tokens": 999,
"seed": 123,
"stop": ["\n", "stop"],
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
"response_format": {"type": "json_object"}
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
"seed": 123.0,
"stop": []any{"\n", "stop"},
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
},
Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
{
name: "chat handler with image content",
body: `{
@@ -402,55 +363,6 @@ func TestCompletionsMiddleware(t *testing.T) {
Stream: &False,
},
},
{
name: "completions handler stream",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler stream with usage",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"stream_options": {"include_usage": true},
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler error forwarding",
body: `{