use embeddings to give chat additional context

This commit is contained in:
Bruce MacDonald 2023-07-11 19:49:29 +02:00
parent c355906cd7
commit 6ee1822105
7 changed files with 174 additions and 9 deletions

View File

@ -98,3 +98,33 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
return fn(resp)
})
}
func (c *Client) Embedding(ctx context.Context, req EmbeddingRequest) (*EmbeddingResponse, error) {
var buf *bytes.Buffer
bts, err := json.Marshal(&req)
if err != nil {
return nil, err
}
buf = bytes.NewBuffer(bts)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.JoinPath("/api/embedding").String(), buf)
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
response, err := http.DefaultClient.Do(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
var resp EmbeddingResponse
if err := json.NewDecoder(response.Body).Decode(&resp); err != nil {
return nil, fmt.Errorf("unmarshal embedding: %w", err)
}
return &resp, nil
}

View File

@ -98,5 +98,5 @@ type EmbeddingRequest struct {
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
Embedding []float64 `json:"embedding"`
}

View File

@ -15,6 +15,7 @@ import (
"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
"golang.org/x/term"
"gonum.org/v1/gonum/mat"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/server"
@ -67,7 +68,8 @@ func RunGenerate(_ *cobra.Command, args []string) error {
// join all args into a single prompt
prompt := strings.Join(args[1:], " ")
if len(args) > 1 {
return generate(args[0], prompt)
_, err := generate(args[0], prompt)
return err
}
if term.IsTerminal(int(os.Stdin.Fd())) {
@ -77,7 +79,67 @@ func RunGenerate(_ *cobra.Command, args []string) error {
return generateBatch(args[0])
}
func generate(model, prompt string) error {
// TODO: rather than setting this, add an endpoint to the server to tokenize the prompt so we can get the real length of the prompt
const maxChars = 250 // currently the max default context in the server is 512, this sets characters to a range that hopefully won't exceed the token length
// stuffPrompt adds adds more context to the prompt from the current session
func stuffPrompt(prompt string, similar VectorSlice) string {
if len(prompt) >= maxChars {
return prompt
}
for _, s := range similar {
userInput := fmt.Sprintf(". I previously stated %q", s.UserInput)
if len(prompt)+len(userInput) < maxChars {
prompt += userInput
}
modelResponse := fmt.Sprintf(". You previously responded %q", s.ModelResponse)
if len(prompt)+len(modelResponse) < maxChars {
prompt += modelResponse
}
}
return prompt
}
// generateWithEmbeddings adds additional context to the prompt from the current session
func generateWithEmbeddings(model string, embeddings *VectorSlice, prompt string) error {
input := prompt
client := api.NewClient()
// get the embedding of the current prompt to find similar prompts stored in memory
e, err := client.Embedding(context.Background(), api.EmbeddingRequest{Model: model, Input: prompt})
if err != nil {
return err
}
embedding := mat.NewVecDense(len(e.Embedding), e.Embedding)
similar := embeddings.NearestNeighbors(embedding, 2)
prompt = stuffPrompt(input, similar)
generated, err := generate(model, prompt)
if err != nil {
return err
}
go func() {
fullText := fmt.Sprintf("%s %s", prompt, generated)
// if the prompt got stuffed, only add the original input to avoid nesting user inputs
if prompt != input {
fullText = fmt.Sprintf("%s %s", input, generated)
}
e, err = client.Embedding(context.Background(), api.EmbeddingRequest{Model: model, Input: fullText})
if err != nil {
return
}
embeddings.Add(Vector{
UserInput: prompt,
ModelResponse: generated,
Data: mat.NewVecDense(len(e.Embedding), e.Embedding),
})
}()
return nil
}
func generate(model, prompt string) (string, error) {
result := ""
if len(strings.TrimSpace(prompt)) > 0 {
client := api.NewClient()
@ -107,25 +169,28 @@ func generate(model, prompt string) error {
}
fmt.Print(resp.Response)
result += resp.Response
return nil
}
if err := client.Generate(context.Background(), &request, fn); err != nil {
return err
return "", err
}
fmt.Println()
fmt.Println()
}
return nil
return result, nil
}
// generateInteractive runs the generator in interactive mode which has a memory of previous prompts
func generateInteractive(model string) error {
fmt.Print(">>> ")
scanner := bufio.NewScanner(os.Stdin)
embeddings := &VectorSlice{}
for scanner.Scan() {
if err := generate(model, scanner.Text()); err != nil {
if err := generateWithEmbeddings(model, embeddings, scanner.Text()); err != nil {
return err
}
@ -140,7 +205,7 @@ func generateBatch(model string) error {
for scanner.Scan() {
prompt := scanner.Text()
fmt.Printf(">>> %s\n", prompt)
if err := generate(model, prompt); err != nil {
if _, err := generate(model, prompt); err != nil {
return err
}
}

67
cmd/vector.go Normal file
View File

@ -0,0 +1,67 @@
package cmd
import (
"sort"
"gonum.org/v1/gonum/mat"
)
type Vector struct {
Data *mat.VecDense // the embedding vector
UserInput string // the user input segment of the text
ModelResponse string // the model response segment of the text
}
// VectorSimilarity is a vector and its similarity to another vector
type VectorSimilarity struct {
Vector Vector
Similarity float64
}
type BySimilarity []VectorSimilarity
func (a BySimilarity) Len() int { return len(a) }
func (a BySimilarity) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a BySimilarity) Less(i, j int) bool { return a[i].Similarity > a[j].Similarity }
// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
// This value will range from -1 to 1, where 1 means the vectors are identical.
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
dotProduct := mat.Dot(vec1, vec2)
norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
if norms == 0 {
return 0
}
return dotProduct / norms
}
type VectorSlice []Vector
func (vs *VectorSlice) Add(v Vector) {
*vs = append(*vs, v)
}
func (vs *VectorSlice) Length() int {
return len(*vs)
}
func (vs *VectorSlice) NearestNeighbors(embedding *mat.VecDense, n int) VectorSlice {
if vs.Length() == 0 {
return VectorSlice{}
}
similarities := make([]VectorSimilarity, vs.Length())
for i, v := range *vs {
similarity := cosineSimilarity(embedding, v.Data)
similarities[i] = VectorSimilarity{Vector: v, Similarity: similarity}
}
sort.Sort(BySimilarity(similarities))
if len(similarities) < n {
n = len(similarities)
}
result := make(VectorSlice, n)
for i := 0; i < n; i++ {
result[i] = similarities[i].Vector
}
return result
}

1
go.mod
View File

@ -43,6 +43,7 @@ require (
golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0
golang.org/x/text v0.10.0 // indirect
gonum.org/v1/gonum v0.13.0
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

2
go.sum
View File

@ -131,6 +131,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

View File

@ -234,7 +234,7 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
return 0, io.EOF
}
func (llm *llama) Embed(input string) ([]float32, error) {
func (llm *llama) Embed(input string) ([]float64, error) {
if !llm.EmbeddingOnly {
return nil, errors.New("llama: embedding not enabled")
}
@ -249,7 +249,7 @@ func (llm *llama) Embed(input string) ([]float32, error) {
Len: n,
Cap: n,
}
embedSlice := *(*[]float32)(unsafe.Pointer(&header))
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
return embedSlice, nil
}