Truncation Integration Tests

This commit is contained in:
Roy Han
2024-07-01 16:26:30 -07:00
parent e068e7f698
commit 1a0c8b363c
3 changed files with 105 additions and 20 deletions

View File

@@ -395,7 +395,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
truncate := func(s string) (string, error) {
checkFit := func(s string, truncate bool) (string, error) {
tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -403,8 +403,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
if len(tokens) > opts.NumCtx {
tokens = tokens[:opts.NumCtx]
return runner.llama.Detokenize(c.Request.Context(), tokens)
if truncate {
tokens = tokens[:opts.NumCtx]
return runner.llama.Detokenize(c.Request.Context(), tokens)
} else {
return "", fmt.Errorf("input length exceeds maximum context length")
}
}
return s, nil
@@ -418,12 +422,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
return
}
if *req.Truncate {
reqEmbed, err = truncate(reqEmbed)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
case []any:
@@ -435,12 +437,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed {
if s, ok := v.(string); ok {
if *req.Truncate {
s, err = truncate(s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s, err = checkFit(s, *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbedArray[i] = s
} else {