From c355906cd7dc7a18914a6a59f6f333f0f1622d39 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 11 Jul 2023 17:04:21 +0200
Subject: [PATCH] embedding back-end

---
 api/types.go     | 11 +++++++++++
 llama/llama.go   | 24 ++++++++++++++++++++++++
 server/routes.go | 42 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+)

diff --git a/api/types.go b/api/types.go
index e0f1f4da3..772de5fa2 100644
--- a/api/types.go
+++ b/api/types.go
@@ -89,3 +89,14 @@ func DefaultOptions() Options {
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+type EmbeddingRequest struct {
+	Model string `json:"model"`
+	Input string `json:"input"`
+
+	Options `json:"options"`
+}
+
+type EmbeddingResponse struct {
+	Embedding []float32 `json:"embedding"`
+}
diff --git a/llama/llama.go b/llama/llama.go
index b64acd7e8..eec3bf6ac 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -81,6 +81,7 @@ import (
 	"errors"
 	"io"
 	"os"
+	"reflect"
 	"strings"
 	"unsafe"
 
@@ -232,3 +233,26 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
 
 	return 0, io.EOF
 }
+
+func (llm *llama) Embed(input string) ([]float32, error) {
+	if !llm.EmbeddingOnly {
+		return nil, errors.New("llama: embedding not enabled")
+	}
+	if tokens := llm.tokenize(input); tokens != nil {
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
+			return nil, errors.New("llama: eval")
+		}
+		n := int(C.llama_n_embd(llm.ctx))
+		embedPtr := C.llama_get_embeddings(llm.ctx)
+		header := reflect.SliceHeader{
+			Data: uintptr(unsafe.Pointer(embedPtr)),
+			Len:  n,
+			Cap:  n,
+		}
+		embedSlice := *(*[]float32)(unsafe.Pointer(&header))
+
+		return embedSlice, nil
+	}
+
+	return nil, errors.New("llama: tokenize embedding")
+}
diff --git a/server/routes.go b/server/routes.go
index 47551f15a..ac5a17a19 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -117,6 +117,46 @@ func generate(c *gin.Context) {
 	}
 }
 
+func embedding(c *gin.Context) {
+	req := api.EmbeddingRequest{
+		Options: api.DefaultOptions(),
+	}
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+		return
+	}
+
+	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
+		req.Model = remoteModel.FullName()
+	}
+	if _, err := os.Stat(req.Model); err != nil {
+		if !errors.Is(err, os.ErrNotExist) {
+			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+			return
+		}
+		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
+	}
+
+	req.Options.EmbeddingOnly = true // this is required for this endpoint
+	llm, err := llama.New(req.Model, req.Options)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer llm.Close()
+
+	embedding, err := llm.Embed(req.Input)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	resp := api.EmbeddingResponse{
+		Embedding: embedding,
+	}
+	c.JSON(http.StatusOK, resp)
+}
+
 func Serve(ln net.Listener) error {
 	r := gin.Default()
 
@@ -165,6 +205,8 @@ func Serve(ln net.Listener) error {
 		})
 	})
 
+	r.POST("/api/embedding", embedding)
+
 	r.POST("/api/generate", generate)
 
 	log.Printf("Listening on %s", ln.Addr())
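
Not part of the patch: a minimal client sketch showing how the new POST /api/embedding route could be exercised with the request/response shapes the diff introduces. The listen address, the model name, and the local struct names here are assumptions; the options field is omitted so the server falls back to api.DefaultOptions().

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// Mirrors api.EmbeddingRequest from the patch, without the embedded Options.
type embeddingRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`
}

// Mirrors api.EmbeddingResponse from the patch.
type embeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

func main() {
	body, err := json.Marshal(embeddingRequest{
		Model: "llama2-7b", // hypothetical model name
		Input: "The quick brown fox jumps over the lazy dog",
	})
	if err != nil {
		log.Fatal(err)
	}

	// The address is an assumption; point this at wherever Serve is listening.
	resp, err := http.Post("http://127.0.0.1:11434/api/embedding", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out embeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("embedding has %d dimensions\n", len(out.Embedding))
}

The length of the returned vector is whatever llama_n_embd reports for the loaded model, since Embed copies the context's embedding buffer directly into the response.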