Truncation Integration Tests
This commit is contained in:
@@ -395,7 +395,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
truncate := func(s string) (string, error) {
|
||||
checkFit := func(s string, truncate bool) (string, error) {
|
||||
tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -403,8 +403,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(tokens) > opts.NumCtx {
|
||||
tokens = tokens[:opts.NumCtx]
|
||||
return runner.llama.Detokenize(c.Request.Context(), tokens)
|
||||
if truncate {
|
||||
tokens = tokens[:opts.NumCtx]
|
||||
return runner.llama.Detokenize(c.Request.Context(), tokens)
|
||||
} else {
|
||||
return "", fmt.Errorf("input length exceeds maximum context length")
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
@@ -418,12 +422,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
||||
return
|
||||
}
|
||||
if *req.Truncate {
|
||||
reqEmbed, err = truncate(reqEmbed)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
||||
case []any:
|
||||
@@ -435,12 +437,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
reqEmbedArray := make([]string, len(reqEmbed))
|
||||
for i, v := range reqEmbed {
|
||||
if s, ok := v.(string); ok {
|
||||
if *req.Truncate {
|
||||
s, err = truncate(s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
s, err = checkFit(s, *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
reqEmbedArray[i] = s
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user