Compare commits

..

7 Commits

Author SHA1 Message Date
Josh Yan
7066120aaf refactor err 2024-07-22 11:34:01 -07:00
Josh Yan
ca1fbc5789 cmt 2024-07-19 15:23:30 -07:00
Josh Yan
aaec2be2ee gin header 2024-07-17 12:12:43 -07:00
Josh Yan
9b5bf861dd use new err 2024-07-17 11:35:34 -07:00
Josh Yan
3e89435605 bad request to templ err 2024-07-17 09:59:20 -07:00
Josh Yan
f7b6cd7934 tests 2024-07-16 17:31:12 -07:00
Josh Yan
5bfb07b500 validate template 2024-07-16 17:11:39 -07:00
4 changed files with 52 additions and 82 deletions

View File

@@ -27,11 +27,6 @@ chat_completion = client.chat.completions.create(
], ],
model='llama3', model='llama3',
) )
completion = client.completions.create(
model="llama3",
prompt="Say this is a test"
)
``` ```
### OpenAI JavaScript library ### OpenAI JavaScript library
@@ -50,11 +45,6 @@ const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }], messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama3', model: 'llama3',
}) })
const completion = await openai.completions.create({
model: "llama3",
prompt: "Say this is a test.",
})
``` ```
### `curl` ### `curl`
@@ -76,12 +66,6 @@ curl http://localhost:11434/v1/chat/completions \
] ]
}' }'
curl http://localhost:11434/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3",
"prompt": "Say this is a test"
}'
``` ```
## Endpoints ## Endpoints
@@ -119,71 +103,8 @@ curl http://localhost:11434/v1/completions \
- [ ] `user` - [ ] `user`
- [ ] `n` - [ ] `n`
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [x] `suffix`
- [ ] `best_of`
- [ ] `echo`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes #### Notes
- `prompt` currently only accepts a string
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [ ] `best_of`
- [ ] `echo`
- [ ] `suffix`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes
- `prompt` currently only accepts a string
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached - `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models ## Models

View File

@@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
layers = append(layers, baseLayer.Layer) layers = append(layers, baseLayer.Layer)
} }
case "license", "template", "system": case "license", "template", "system":
if c.Name == "template" {
if _, err := template.Parse(c.Args); err != nil {
return fmt.Errorf("%w: %s", errBadTemplate, err)
}
}
if c.Name != "license" { if c.Name != "license" {
// replace // replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool { layers = slices.DeleteFunc(layers, func(layer *Layer) bool {

View File

@@ -56,6 +56,7 @@ func init() {
} }
var errRequired = errors.New("is required") var errRequired = errors.New("is required")
var errBadTemplate = errors.New("template error")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
@@ -613,7 +614,9 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
defer cancel() defer cancel()
quantization := cmp.Or(r.Quantize, r.Quantization) quantization := cmp.Or(r.Quantize, r.Quantization)
- if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
+ if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); errors.Is(err, errBadTemplate) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
} else if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@@ -1201,11 +1204,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
return return
} }
case gin.H: case gin.H:
status, ok := r["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
if errorMsg, ok := r["error"].(string); ok { if errorMsg, ok := r["error"].(string); ok {
- c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
+ c.JSON(status, gin.H{"error": errorMsg})
return return
} else { } else {
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
+ c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
return return
} }
default: default:

View File

@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
if string(system) != "Say bye!" { if string(system) != "Say bye!" {
t.Errorf("expected \"Say bye!\", actual %s", system) t.Errorf("expected \"Say bye!\", actual %s", system)
} }
t.Run("incomplete template", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with unclosed if", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with undefined function", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
} }
func TestCreateLicenses(t *testing.T) { func TestCreateLicenses(t *testing.T) {