embedding back-end

This commit is contained in:
Bruce MacDonald 2023-07-11 17:04:21 +02:00
parent 62620914e9
commit c355906cd7
3 changed files with 77 additions and 0 deletions

View File

@@ -89,3 +89,14 @@ func DefaultOptions() Options {
		NumThread: runtime.NumCPU(),
	}
}

type EmbeddingRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`

	Options `json:"options"`
}

type EmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}
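Because Options is an embedded field with a JSON tag, the runner options nest under an "options" key in the request body. A minimal sketch of the resulting wire format (the import path and model name are assumptions, not part of this commit):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/jmorganca/ollama/api" // assumed import path for the api package
)

func main() {
	req := api.EmbeddingRequest{
		Model:   "llama2", // hypothetical model name
		Input:   "The quick brown fox",
		Options: api.DefaultOptions(),
	}

	// The embedded Options serialize as a nested object under "options":
	// {"model":"llama2","input":"The quick brown fox","options":{...}}
	body, err := json.MarshalIndent(req, "", "  ")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
}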

View File

@@ -81,6 +81,7 @@ import (
	"errors"
	"io"
	"os"
	"reflect"
	"strings"
	"unsafe"
@@ -232,3 +233,26 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
	return 0, io.EOF
}
func (llm *llama) Embed(input string) ([]float32, error) {
	if !llm.EmbeddingOnly {
		return nil, errors.New("llama: embedding not enabled")
	}

	if tokens := llm.tokenize(input); tokens != nil {
		// Evaluate the prompt so llama.cpp computes the embeddings for it.
		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
			return nil, errors.New("llama: eval")
		}

		// View the context-owned C float buffer as a []float32 without copying.
		n := int(C.llama_n_embd(llm.ctx))
		embedPtr := C.llama_get_embeddings(llm.ctx)
		header := reflect.SliceHeader{
			Data: uintptr(unsafe.Pointer(embedPtr)),
			Len:  n,
			Cap:  n,
		}
		embedSlice := *(*[]float32)(unsafe.Pointer(&header))

		return embedSlice, nil
	}

	return nil, errors.New("llama: tokenize embedding")
}
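Since the returned slice is built directly over llama.cpp's internal buffer, it only stays valid while the context is alive. A minimal caller-side sketch (the helper name is hypothetical, not part of this commit) that copies the result into Go-managed memory before the context is closed:

// copyEmbedding is a hypothetical helper: it copies the embedding out of the
// context-owned buffer so the result remains usable after llm.Close().
func copyEmbedding(llm *llama, input string) ([]float32, error) {
	embedding, err := llm.Embed(input)
	if err != nil {
		return nil, err
	}
	out := make([]float32, len(embedding))
	copy(out, embedding)
	return out, nil
}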

View File

@@ -117,6 +117,46 @@ func generate(c *gin.Context) {
	}
}
func embedding(c *gin.Context) {
	req := api.EmbeddingRequest{
		Options: api.DefaultOptions(),
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
		return
	}

	// Resolve a registry model name to its full name if one matches.
	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
		req.Model = remoteModel.FullName()
	}

	// If the model is not a path on disk, fall back to the local model cache.
	if _, err := os.Stat(req.Model); err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
			return
		}
		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
	}

	// Embedding output is only produced when the context is opened in embedding-only mode,
	// so this is required for this endpoint.
	req.Options.EmbeddingOnly = true

	llm, err := llama.New(req.Model, req.Options)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	defer llm.Close()

	embedding, err := llm.Embed(req.Input)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	resp := api.EmbeddingResponse{
		Embedding: embedding,
	}
	c.JSON(http.StatusOK, resp)
}
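End to end, the handler accepts a model name and an input string and returns the raw embedding vector. A rough client-side sketch of exercising the new route (the host, port, and model name are assumptions; the request and response shapes come from the api types added above):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical local server address and model name.
	body, err := json.Marshal(map[string]any{
		"model": "llama2",
		"input": "The quick brown fox",
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://127.0.0.1:11434/api/embedding", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Embedding []float32 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println("embedding length:", len(out.Embedding))
}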
func Serve(ln net.Listener) error {
	r := gin.Default()
@@ -165,6 +205,8 @@ func Serve(ln net.Listener) error {
	})
	})

	r.POST("/api/embedding", embedding)
	r.POST("/api/generate", generate)

	log.Printf("Listening on %s", ln.Addr())