diff --git a/api/client.go b/api/client.go index ccbcbf6bc..cc5e00268 100644 --- a/api/client.go +++ b/api/client.go @@ -98,3 +98,33 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc return fn(resp) }) } + +func (c *Client) Embedding(ctx context.Context, req EmbeddingRequest) (*EmbeddingResponse, error) { + var buf *bytes.Buffer + bts, err := json.Marshal(&req) + if err != nil { + return nil, err + } + + buf = bytes.NewBuffer(bts) + + request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.JoinPath("/api/embedding").String(), buf) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + + response, err := http.DefaultClient.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + var resp EmbeddingResponse + if err := json.NewDecoder(response.Body).Decode(&resp); err != nil { + return nil, fmt.Errorf("unmarshal embedding: %w", err) + } + return &resp, nil +} diff --git a/api/types.go b/api/types.go index 772de5fa2..f1b234bc3 100644 --- a/api/types.go +++ b/api/types.go @@ -98,5 +98,5 @@ type EmbeddingRequest struct { } type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` + Embedding []float64 `json:"embedding"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 8421b8f58..20a050dfe 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -15,6 +15,7 @@ import ( "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" "golang.org/x/term" + "gonum.org/v1/gonum/mat" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/server" @@ -67,7 +68,8 @@ func RunGenerate(_ *cobra.Command, args []string) error { // join all args into a single prompt prompt := strings.Join(args[1:], " ") if len(args) > 1 { - return generate(args[0], prompt) + _, err := generate(args[0], prompt) + return err } if term.IsTerminal(int(os.Stdin.Fd())) { @@ -77,7 +79,67 @@ func RunGenerate(_ *cobra.Command, args []string) error { return generateBatch(args[0]) } -func generate(model, prompt string) error { +// TODO: rather than setting this, add an endpoint to the server to tokenize the prompt so we can get the real length of the prompt +const maxChars = 250 // currently the max default context in the server is 512, this sets characters to a range that hopefully won't exceed the token length + +// stuffPrompt adds adds more context to the prompt from the current session +func stuffPrompt(prompt string, similar VectorSlice) string { + if len(prompt) >= maxChars { + return prompt + } + for _, s := range similar { + userInput := fmt.Sprintf(". I previously stated %q", s.UserInput) + if len(prompt)+len(userInput) < maxChars { + prompt += userInput + } + modelResponse := fmt.Sprintf(". You previously responded %q", s.ModelResponse) + if len(prompt)+len(modelResponse) < maxChars { + prompt += modelResponse + } + } + return prompt +} + +// generateWithEmbeddings adds additional context to the prompt from the current session +func generateWithEmbeddings(model string, embeddings *VectorSlice, prompt string) error { + input := prompt + client := api.NewClient() + // get the embedding of the current prompt to find similar prompts stored in memory + e, err := client.Embedding(context.Background(), api.EmbeddingRequest{Model: model, Input: prompt}) + if err != nil { + return err + } + embedding := mat.NewVecDense(len(e.Embedding), e.Embedding) + similar := embeddings.NearestNeighbors(embedding, 2) + + prompt = stuffPrompt(input, similar) + + generated, err := generate(model, prompt) + if err != nil { + return err + } + + go func() { + fullText := fmt.Sprintf("%s %s", prompt, generated) + // if the prompt got stuffed, only add the original input to avoid nesting user inputs + if prompt != input { + fullText = fmt.Sprintf("%s %s", input, generated) + } + e, err = client.Embedding(context.Background(), api.EmbeddingRequest{Model: model, Input: fullText}) + if err != nil { + return + } + embeddings.Add(Vector{ + UserInput: prompt, + ModelResponse: generated, + Data: mat.NewVecDense(len(e.Embedding), e.Embedding), + }) + }() + return nil +} + +func generate(model, prompt string) (string, error) { + result := "" if len(strings.TrimSpace(prompt)) > 0 { client := api.NewClient() @@ -107,25 +169,28 @@ func generate(model, prompt string) error { } fmt.Print(resp.Response) + result += resp.Response return nil } if err := client.Generate(context.Background(), &request, fn); err != nil { - return err + return "", err } fmt.Println() fmt.Println() } - return nil + return result, nil } +// generateInteractive runs the generator in interactive mode which has a memory of previous prompts func generateInteractive(model string) error { fmt.Print(">>> ") scanner := bufio.NewScanner(os.Stdin) + embeddings := &VectorSlice{} for scanner.Scan() { - if err := generate(model, scanner.Text()); err != nil { + if err := generateWithEmbeddings(model, embeddings, scanner.Text()); err != nil { return err } @@ -140,7 +205,7 @@ func generateBatch(model string) error { for scanner.Scan() { prompt := scanner.Text() fmt.Printf(">>> %s\n", prompt) - if err := generate(model, prompt); err != nil { + if _, err := generate(model, prompt); err != nil { return err } } diff --git a/cmd/vector.go b/cmd/vector.go new file mode 100644 index 000000000..00e1ea121 --- /dev/null +++ b/cmd/vector.go @@ -0,0 +1,67 @@ +package cmd + +import ( + "sort" + + "gonum.org/v1/gonum/mat" +) + +type Vector struct { + Data *mat.VecDense // the embedding vector + UserInput string // the user input segment of the text + ModelResponse string // the model response segment of the text +} + +// VectorSimilarity is a vector and its similarity to another vector +type VectorSimilarity struct { + Vector Vector + Similarity float64 +} + +type BySimilarity []VectorSimilarity + +func (a BySimilarity) Len() int { return len(a) } +func (a BySimilarity) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a BySimilarity) Less(i, j int) bool { return a[i].Similarity > a[j].Similarity } + +// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors. +// This value will range from -1 to 1, where 1 means the vectors are identical. +func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 { + dotProduct := mat.Dot(vec1, vec2) + norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2) + + if norms == 0 { + return 0 + } + return dotProduct / norms +} + +type VectorSlice []Vector + +func (vs *VectorSlice) Add(v Vector) { + *vs = append(*vs, v) +} + +func (vs *VectorSlice) Length() int { + return len(*vs) +} + +func (vs *VectorSlice) NearestNeighbors(embedding *mat.VecDense, n int) VectorSlice { + if vs.Length() == 0 { + return VectorSlice{} + } + similarities := make([]VectorSimilarity, vs.Length()) + for i, v := range *vs { + similarity := cosineSimilarity(embedding, v.Data) + similarities[i] = VectorSimilarity{Vector: v, Similarity: similarity} + } + sort.Sort(BySimilarity(similarities)) + if len(similarities) < n { + n = len(similarities) + } + result := make(VectorSlice, n) + for i := 0; i < n; i++ { + result[i] = similarities[i].Vector + } + return result +} diff --git a/go.mod b/go.mod index 8beb32bd9..496ff7f6c 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( golang.org/x/sys v0.10.0 // indirect golang.org/x/term v0.10.0 golang.org/x/text v0.10.0 // indirect + gonum.org/v1/gonum v0.13.0 google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9189b115b..a7816d682 100644 --- a/go.sum +++ b/go.sum @@ -131,6 +131,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM= +gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/llama/llama.go b/llama/llama.go index eec3bf6ac..b3184a495 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -234,7 +234,7 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s return 0, io.EOF } -func (llm *llama) Embed(input string) ([]float32, error) { +func (llm *llama) Embed(input string) ([]float64, error) { if !llm.EmbeddingOnly { return nil, errors.New("llama: embedding not enabled") } @@ -249,7 +249,7 @@ func (llm *llama) Embed(input string) ([]float32, error) { Len: n, Cap: n, } - embedSlice := *(*[]float32)(unsafe.Pointer(&header)) + embedSlice := *(*[]float64)(unsafe.Pointer(&header)) return embedSlice, nil }