embedding back-end
parent 62620914e9
commit c355906cd7

api/types.go | 11 +++++++++++
api/types.go
@@ -89,3 +89,14 @@ func DefaultOptions() Options {
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+type EmbeddingRequest struct {
+	Model string `json:"model"`
+	Input string `json:"input"`
+
+	Options `json:"options"`
+}
+
+type EmbeddingResponse struct {
+	Embedding []float32 `json:"embedding"`
+}
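For reference, a request built from these new types serializes as in the sketch below. Because `Options` is an embedded field with a `json:"options"` tag, its fields nest under an "options" object. The import path is assumed from this repo's module name; the model name and input are illustrative:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/jmorganca/ollama/api" // assumed import path for this repo's api package
)

func main() {
	// Illustrative request against the new types; "llama2" is a placeholder.
	req := api.EmbeddingRequest{
		Model:   "llama2",
		Input:   "the quick brown fox",
		Options: api.DefaultOptions(),
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"model":"llama2","input":"the quick brown fox","options":{...}}
}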

@@ -81,6 +81,7 @@ import (
 	"errors"
 	"io"
 	"os"
+	"reflect"
 	"strings"
 	"unsafe"
 )

@@ -232,3 +233,26 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
 
 	return 0, io.EOF
 }
+
+func (llm *llama) Embed(input string) ([]float32, error) {
+	if !llm.EmbeddingOnly {
+		return nil, errors.New("llama: embedding not enabled")
+	}
+	if tokens := llm.tokenize(input); tokens != nil {
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
+			return nil, errors.New("llama: eval")
+		}
+		n := int(C.llama_n_embd(llm.ctx))
+		embedPtr := C.llama_get_embeddings(llm.ctx)
+		header := reflect.SliceHeader{
+			Data: uintptr(unsafe.Pointer(embedPtr)),
+			Len:  n,
+			Cap:  n,
+		}
+		embedSlice := *(*[]float32)(unsafe.Pointer(&header))
+
+		return embedSlice, nil
+	}
+
+	return nil, errors.New("llama: tokenize embedding")
+}
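A note on the reflect.SliceHeader construction above: it builds a Go slice that aliases memory owned by the llama.cpp context, so the returned embedding is presumably only valid until the context is freed (e.g. by llm.Close()). Since the file already relies on unsafe.SliceData (Go 1.20+), unsafe.Slice could express the same aliasing without the deprecated SliceHeader, and a caller that needs the values to outlive the context could copy them first. A sketch, not part of this commit; goEmbedding is a hypothetical helper living in the same cgo file:

// goEmbedding copies the context-owned embedding buffer into Go-managed
// memory so it remains valid after the llama context is freed.
// Hypothetical helper, not part of this commit.
func goEmbedding(embedPtr *C.float, n int) []float32 {
	// Go 1.17+ replacement for the reflect.SliceHeader construction.
	src := unsafe.Slice((*float32)(unsafe.Pointer(embedPtr)), n)
	out := make([]float32, n)
	copy(out, src)
	return out
}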

@@ -117,6 +117,46 @@ func generate(c *gin.Context) {
 	}
 }
+
+func embedding(c *gin.Context) {
+	req := api.EmbeddingRequest{
+		Options: api.DefaultOptions(),
+	}
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+		return
+	}
+
+	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
+		req.Model = remoteModel.FullName()
+	}
+	if _, err := os.Stat(req.Model); err != nil {
+		if !errors.Is(err, os.ErrNotExist) {
+			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+			return
+		}
+		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
+	}
+
+	req.Options.EmbeddingOnly = true // this is required for this endpoint
+	llm, err := llama.New(req.Model, req.Options)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer llm.Close()
+
+	embedding, err := llm.Embed(req.Input)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	resp := api.EmbeddingResponse{
+		Embedding: embedding,
+	}
+	c.JSON(http.StatusOK, resp)
+}
 
 func Serve(ln net.Listener) error {
 	r := gin.Default()

@@ -165,6 +205,8 @@ func Serve(ln net.Listener) error {
 		})
 	})
 
+	r.POST("/api/embedding", embedding)
+
 	r.POST("/api/generate", generate)
 
 	log.Printf("Listening on %s", ln.Addr())
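With the route registered, the handler above can be exercised with any HTTP client. A minimal Go sketch, assuming the server listens on localhost:11434 (ollama's default port) and that "llama2" resolves to a model file under the cache dir:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	// Request body matching api.EmbeddingRequest; the model name is illustrative.
	// Omitted options fall back to api.DefaultOptions() on the server side.
	body, _ := json.Marshal(map[string]string{
		"model": "llama2",
		"input": "the quick brown fox",
	})
	resp, err := http.Post("http://localhost:11434/api/embedding", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Decode into the api.EmbeddingResponse shape.
	var out struct {
		Embedding []float32 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(out.Embedding)) // embedding dimension of the loaded model
}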