diff --git a/server/routes.go b/server/routes.go index b19a40fbc..977a13ff2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -752,9 +752,15 @@ func (s *Server) EmbedHandler(c *gin.Context) { return err } // TODO: this first normalization should be done by the model - embedding = normalize(embedding) + embedding, err = normalize(embedding) + if err != nil { + return err + } if req.Dimensions > 0 && req.Dimensions < len(embedding) { - embedding = normalize(embedding[:req.Dimensions]) + embedding, err = normalize(embedding[:req.Dimensions]) + if err != nil { + return err + } } embeddings[i] = embedding atomic.AddUint64(&totalTokens, uint64(tokenCount)) @@ -787,9 +793,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { c.JSON(http.StatusOK, resp) } -func normalize(vec []float32) []float32 { +func normalize(vec []float32) ([]float32, error) { var sum float32 for _, v := range vec { + if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) { + return nil, errors.New("embedding contains NaN or Inf values") + } sum += v * v } @@ -797,7 +806,7 @@ func normalize(vec []float32) []float32 { for i := range vec { vec[i] *= norm } - return vec + return vec, nil } func (s *Server) EmbeddingsHandler(c *gin.Context) { @@ -2395,4 +2404,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message { } return msgs } - diff --git a/server/routes_test.go b/server/routes_test.go index 39d4f290d..e470b9384 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -723,15 +723,20 @@ func TestShow(t *testing.T) { func TestNormalize(t *testing.T) { type testCase struct { - input []float32 + input []float32 + expectError bool } testCases := []testCase{ - {input: []float32{1}}, - {input: []float32{0, 1, 2, 3}}, - {input: []float32{0.1, 0.2, 0.3}}, - {input: []float32{-0.1, 0.2, 0.3, -0.4}}, - {input: []float32{0, 0, 0}}, + {input: []float32{1}, expectError: false}, + {input: []float32{0, 1, 2, 3}, expectError: false}, + {input: []float32{0.1, 0.2, 0.3}, expectError: false}, + {input: []float32{-0.1, 0.2, 0.3, -0.4}, expectError: false}, + {input: []float32{0, 0, 0}, expectError: false}, + {input: []float32{float32(math.NaN()), 0.2, 0.3}, expectError: true}, + {input: []float32{0.1, float32(math.NaN()), 0.3}, expectError: true}, + {input: []float32{float32(math.Inf(1)), 0.2, 0.3}, expectError: true}, + {input: []float32{float32(math.Inf(-1)), 0.2, 0.3}, expectError: true}, } isNormalized := func(vec []float32) (res bool) { @@ -748,9 +753,18 @@ func TestNormalize(t *testing.T) { for _, tc := range testCases { t.Run("", func(t *testing.T) { - normalized := normalize(tc.input) - if !isNormalized(normalized) { - t.Errorf("Vector %v is not normalized", tc.input) + normalized, err := normalize(tc.input) + if tc.expectError { + if err == nil { + t.Errorf("Expected error for input %v, but got none", tc.input) + } + } else { + if err != nil { + t.Errorf("Unexpected error for input %v: %v", tc.input, err) + } + if !isNormalized(normalized) { + t.Errorf("Vector %v is not normalized", tc.input) + } } }) }