use embeddings to give chat additional context

commit 6ee1822105
parent c355906cd7
@@ -98,3 +98,33 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
 		return fn(resp)
 	})
 }
+
+func (c *Client) Embedding(ctx context.Context, req EmbeddingRequest) (*EmbeddingResponse, error) {
+	var buf *bytes.Buffer
+	bts, err := json.Marshal(&req)
+	if err != nil {
+		return nil, err
+	}
+
+	buf = bytes.NewBuffer(bts)
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.JoinPath("/api/embedding").String(), buf)
+	if err != nil {
+		return nil, err
+	}
+
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("Accept", "application/json")
+
+	response, err := http.DefaultClient.Do(request)
+	if err != nil {
+		return nil, err
+	}
+	defer response.Body.Close()
+
+	var resp EmbeddingResponse
+	if err := json.NewDecoder(response.Body).Decode(&resp); err != nil {
+		return nil, fmt.Errorf("unmarshal embedding: %w", err)
+	}
+	return &resp, nil
+}
@@ -98,5 +98,5 @@ type EmbeddingRequest struct {
 }
 
 type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
+	Embedding []float64 `json:"embedding"`
 }
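For reference, a minimal sketch (not part of the commit) of how the new Embedding client method might be called. It assumes a locally running server that serves the new /api/embedding route; the model name and prompt are placeholders.

package main

import (
	"context"
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client := api.NewClient()
	// ask the server to embed a prompt with a (placeholder) model
	resp, err := client.Embedding(context.Background(), api.EmbeddingRequest{
		Model: "llama2",
		Input: "the quick brown fox",
	})
	if err != nil {
		panic(err)
	}
	// resp.Embedding is a []float64 after this commit
	fmt.Println(len(resp.Embedding))
}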
cmd/cmd.go (77 changed lines)
@@ -15,6 +15,7 @@ import (
 	"github.com/schollz/progressbar/v3"
 	"github.com/spf13/cobra"
 	"golang.org/x/term"
+	"gonum.org/v1/gonum/mat"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/server"
@@ -67,7 +68,8 @@ func RunGenerate(_ *cobra.Command, args []string) error {
 	// join all args into a single prompt
 	prompt := strings.Join(args[1:], " ")
 	if len(args) > 1 {
-		return generate(args[0], prompt)
+		_, err := generate(args[0], prompt)
+		return err
 	}
 
 	if term.IsTerminal(int(os.Stdin.Fd())) {
@@ -77,7 +79,67 @@ func RunGenerate(_ *cobra.Command, args []string) error {
 	return generateBatch(args[0])
 }
 
-func generate(model, prompt string) error {
+// TODO: rather than setting this, add an endpoint to the server to tokenize the prompt so we can get the real length of the prompt
+const maxChars = 250 // currently the max default context in the server is 512, this sets characters to a range that hopefully won't exceed the token length
+
+// stuffPrompt adds more context to the prompt from the current session
+func stuffPrompt(prompt string, similar VectorSlice) string {
+	if len(prompt) >= maxChars {
+		return prompt
+	}
+	for _, s := range similar {
+		userInput := fmt.Sprintf(". I previously stated %q", s.UserInput)
+		if len(prompt)+len(userInput) < maxChars {
+			prompt += userInput
+		}
+		modelResponse := fmt.Sprintf(". You previously responded %q", s.ModelResponse)
+		if len(prompt)+len(modelResponse) < maxChars {
+			prompt += modelResponse
+		}
+	}
+	return prompt
+}
+
+// generateWithEmbeddings adds additional context to the prompt from the current session
+func generateWithEmbeddings(model string, embeddings *VectorSlice, prompt string) error {
+	input := prompt
+	client := api.NewClient()
+	// get the embedding of the current prompt to find similar prompts stored in memory
+	e, err := client.Embedding(context.Background(), api.EmbeddingRequest{Model: model, Input: prompt})
+	if err != nil {
+		return err
+	}
+	embedding := mat.NewVecDense(len(e.Embedding), e.Embedding)
+	similar := embeddings.NearestNeighbors(embedding, 2)
+
+	prompt = stuffPrompt(input, similar)
+
+	generated, err := generate(model, prompt)
+	if err != nil {
+		return err
+	}
+
+	go func() {
+		fullText := fmt.Sprintf("%s %s", prompt, generated)
+		// if the prompt got stuffed, only add the original input to avoid nesting user inputs
+		if prompt != input {
+			fullText = fmt.Sprintf("%s %s", input, generated)
+		}
+		e, err = client.Embedding(context.Background(), api.EmbeddingRequest{Model: model, Input: fullText})
+		if err != nil {
+			return
+		}
+		embeddings.Add(Vector{
+			UserInput:     prompt,
+			ModelResponse: generated,
+			Data:          mat.NewVecDense(len(e.Embedding), e.Embedding),
+		})
+	}()
+	return nil
+}
+
+func generate(model, prompt string) (string, error) {
+	result := ""
+	if len(strings.TrimSpace(prompt)) > 0 {
+		client := api.NewClient()
@@ -107,25 +169,28 @@ func generate(model, prompt string) error {
 			}
 
 			fmt.Print(resp.Response)
+			result += resp.Response
 			return nil
 		}
 
 		if err := client.Generate(context.Background(), &request, fn); err != nil {
-			return err
+			return "", err
 		}
 
 		fmt.Println()
 		fmt.Println()
 	}
 
-	return nil
+	return result, nil
 }
 
+// generateInteractive runs the generator in interactive mode which has a memory of previous prompts
 func generateInteractive(model string) error {
 	fmt.Print(">>> ")
 	scanner := bufio.NewScanner(os.Stdin)
+	embeddings := &VectorSlice{}
 	for scanner.Scan() {
-		if err := generate(model, scanner.Text()); err != nil {
+		if err := generateWithEmbeddings(model, embeddings, scanner.Text()); err != nil {
 			return err
 		}
 
@@ -140,7 +205,7 @@ func generateBatch(model string) error {
 	for scanner.Scan() {
 		prompt := scanner.Text()
 		fmt.Printf(">>> %s\n", prompt)
-		if err := generate(model, prompt); err != nil {
+		if _, err := generate(model, prompt); err != nil {
 			return err
 		}
 	}
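To illustrate the new interactive flow, a small hypothetical example (same cmd package, values made up) of what stuffPrompt produces when two earlier exchanges are recalled as nearest neighbors; the result is shown in the comments.

package cmd

import "fmt"

func exampleStuffPrompt() {
	similar := VectorSlice{
		{UserInput: "what is the capital of France", ModelResponse: "Paris"},
		{UserInput: "and of Spain", ModelResponse: "Madrid"},
	}
	stuffed := stuffPrompt("what did I ask you before?", similar)
	fmt.Println(stuffed)
	// what did I ask you before?. I previously stated "what is the capital of France". You previously responded "Paris". I previously stated "and of Spain". You previously responded "Madrid"
	// appending stops once the result would reach maxChars (250 characters)
}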
@@ -0,0 +1,67 @@
+package cmd
+
+import (
+	"sort"
+
+	"gonum.org/v1/gonum/mat"
+)
+
+type Vector struct {
+	Data          *mat.VecDense // the embedding vector
+	UserInput     string        // the user input segment of the text
+	ModelResponse string        // the model response segment of the text
+}
+
+// VectorSimilarity is a vector and its similarity to another vector
+type VectorSimilarity struct {
+	Vector     Vector
+	Similarity float64
+}
+
+type BySimilarity []VectorSimilarity
+
+func (a BySimilarity) Len() int           { return len(a) }
+func (a BySimilarity) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a BySimilarity) Less(i, j int) bool { return a[i].Similarity > a[j].Similarity }
+
+// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
+// This value will range from -1 to 1, where 1 means the vectors are identical.
+func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
+	dotProduct := mat.Dot(vec1, vec2)
+	norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
+
+	if norms == 0 {
+		return 0
+	}
+	return dotProduct / norms
+}
+
+type VectorSlice []Vector
+
+func (vs *VectorSlice) Add(v Vector) {
+	*vs = append(*vs, v)
+}
+
+func (vs *VectorSlice) Length() int {
+	return len(*vs)
+}
+
+func (vs *VectorSlice) NearestNeighbors(embedding *mat.VecDense, n int) VectorSlice {
+	if vs.Length() == 0 {
+		return VectorSlice{}
+	}
+	similarities := make([]VectorSimilarity, vs.Length())
+	for i, v := range *vs {
+		similarity := cosineSimilarity(embedding, v.Data)
+		similarities[i] = VectorSimilarity{Vector: v, Similarity: similarity}
+	}
+	sort.Sort(BySimilarity(similarities))
+	if len(similarities) < n {
+		n = len(similarities)
+	}
+	result := make(VectorSlice, n)
+	for i := 0; i < n; i++ {
+		result[i] = similarities[i].Vector
+	}
+	return result
+}
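A quick sketch (also not part of the commit) of how the vector helpers behave, using tiny made-up embeddings in the same cmd package.

package cmd

import (
	"fmt"

	"gonum.org/v1/gonum/mat"
)

func exampleNearestNeighbors() {
	a := mat.NewVecDense(3, []float64{1, 0, 0})
	b := mat.NewVecDense(3, []float64{0.9, 0.1, 0})
	c := mat.NewVecDense(3, []float64{0, 1, 0})

	vs := &VectorSlice{}
	vs.Add(Vector{Data: b, UserInput: "close"})
	vs.Add(Vector{Data: c, UserInput: "far"})

	// cosineSimilarity(a, b) ≈ 0.99 and cosineSimilarity(a, c) == 0,
	// so the single nearest neighbor of a is the entry holding b
	nearest := vs.NearestNeighbors(a, 1)
	fmt.Println(nearest[0].UserInput) // prints "close"
}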
go.mod (1 changed line)
@@ -43,6 +43,7 @@ require (
 	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect
+	gonum.org/v1/gonum v0.13.0
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
go.sum (2 changed lines)
@@ -131,6 +131,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
 golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
+gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
 google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
@@ -234,7 +234,7 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
 	return 0, io.EOF
 }
 
-func (llm *llama) Embed(input string) ([]float32, error) {
+func (llm *llama) Embed(input string) ([]float64, error) {
 	if !llm.EmbeddingOnly {
 		return nil, errors.New("llama: embedding not enabled")
 	}
@@ -249,7 +249,7 @@ func (llm *llama) Embed(input string) ([]float32, error) {
 		Len: n,
 		Cap: n,
 	}
-	embedSlice := *(*[]float32)(unsafe.Pointer(&header))
+	embedSlice := *(*[]float64)(unsafe.Pointer(&header))
 
 	return embedSlice, nil
 }