Compare commits

..

8 Commits

Author SHA1 Message Date
ParthSareen
0de5bbd0fe sample: use json unmarshal for sampling params 2025-03-20 15:03:42 -04:00
Parth Sareen
42a14f7f63 sample: add error handling for empty logits (#9740) 2025-03-20 11:11:18 -07:00
Patrick Devine
f8c3dbe5b5 templates: add autotemplate for gemma3 (#9880)
This change allows the gemma3 template to be autodetected during `ollama
create`.
2025-03-20 00:15:30 -07:00
Jesse Gross
b078dd157c gemma2: Remove second call to Rows
Looks like a merge conflict that broke the model.
2025-03-19 17:28:49 -07:00
Blake Mizerany
2ddacd7516 server/internal/client/ollama: confirm all chunksums were received (#9893)
If the chunksums response is missing a chunk, the client should fail
the download. This changes the client to check that all bytes are
accounted for in the chunksums response.

It is possible there are overlaps or gaps in the chunksums response and
so the size is not the only thing left to check, but this provides
enough coverage for now. We may want to check that chunks are contiguous
later.
2025-03-19 14:59:57 -07:00
Jeffrey Morgan
da0e345200 ml: use input context for extracting outputs (#9875) 2025-03-18 18:08:19 -07:00
Bruce MacDonald
df94175a0f ggml: return error on failure to read tensor data (#9872)
When converting a ggml model if there is a failure to read tensor data a nil error value was being returned. It should be assigned to the actual error from reading.
2025-03-18 16:51:33 -07:00
Bruce MacDonald
61a8825216 convert: return name of unsupported architecture (#9862)
When a model's architecture cannot be converted return the name of the unsupported arch in the error message.
2025-03-18 10:38:28 -07:00
26 changed files with 361 additions and 290 deletions

View File

@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM":
conv = &commandrModel{}
default:
return errors.New("unsupported architecture")
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {

View File

@@ -11,10 +11,9 @@ import (
"slices"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/x448/float16"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/types/bfloat16"
)
type safetensorMetadata struct {

1
go.mod
View File

@@ -16,6 +16,7 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.6.0

2
go.sum
View File

@@ -35,6 +35,8 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View File

@@ -330,7 +330,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}
}
if g.Wait() != nil {
if err := g.Wait(); err != nil {
return nil, err
}

View File

@@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
@@ -211,8 +211,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
return hiddenState.Rows(ctx, outputs), nil
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
}
func init() {

View File

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}

View File

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}

View File

@@ -154,7 +154,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}

View File

@@ -561,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.Seed,
grammar,
)
sampler := sample.NewSampler(req.Options, grammar)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,

View File

@@ -1,12 +1,14 @@
package sample
import (
"encoding/json"
"errors"
"math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
)
@@ -26,6 +28,10 @@ type Sampler struct {
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@@ -94,13 +100,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)
// TODO: this should fall back to greedy sampling
// or topP, topK values etc should be such that
// there are always tokens to sample from
if len(tokens) == 0 {
return token{}, errors.New("no tokens to sample from")
}
var r float32
if s.rng != nil {
r = s.rng.Float32()
@@ -123,43 +122,71 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return 1
})
if math.IsNaN(float64(sum)) {
return token{}, errors.New("sample: logits sum to NaN, check model output")
}
return tokens[idx], nil
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
// SamplerParams contains the validated and normalized parameters for a sampler
type SamplerParams struct {
Temperature float32 `json:"temperature"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
Seed int `json:"seed"`
}
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
type rawParams SamplerParams
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
return err
}
// Validate and normalize after unmarshaling
if p.Temperature < 0.0 {
p.Temperature = 0.0
}
if p.TopP < 0.0 {
p.TopP = 0.0
}
if p.TopP >= 1.0 {
p.TopP = 1.0
}
if p.MinP < 0.0 {
p.MinP = 0.0
}
if p.MinP >= 1.0 {
p.MinP = 1.0
}
return nil
}
// NewSampler creates a new sampler with the given options
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
var params SamplerParams
data, _ := json.Marshal(opts)
_ = json.Unmarshal(data, &params)
var rng *rand.Rand
if seed != -1 {
if params.Seed != -1 {
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
sequence := uint64(params.Seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
}
if temperature < 0.0 {
temperature = 0.0
}
if topP < 0.0 {
topP = 0.0
}
if topP >= 1.0 {
topP = 1.0
}
if minP < 0.0 {
minP = 0.0
}
if minP >= 1.0 {
minP = 1.0
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
topK: params.TopK,
topP: params.TopP,
minP: params.MinP,
temperature: params.Temperature,
grammar: grammar,
}
}

View File

@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
sampler.Sample(logits)
b.ResetTimer()
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
b.ResetTimer()
for b.Loop() {
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
b.ResetTimer()
for b.Loop() {

View File

@@ -1,13 +1,26 @@
package sample
import (
"math"
"math/rand/v2"
"testing"
"github.com/ollama/ollama/api"
)
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
return &api.Options{
Temperature: temperature,
TopK: topK,
TopP: topP,
MinP: minP,
Seed: seed,
}
}
func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil)
sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -19,7 +32,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil)
sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -29,12 +42,35 @@ func TestWeighted(t *testing.T) {
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
// Should get the token with the highest logit
want = int32(0)
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
return
}
}
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
"Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
}
// Generate random logits for benchmarking

View File

@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
softmax(tokens)
tokens = topK(tokens, 20)
// Then apply topP
tokens = topP(tokens, 0.95)
// Test with very high p value
got := topP(tokens, 1.0)
// Should keep tokens until cumsum > 0.95
if len(tokens) > 3 {
// Should keep all tokens since p is 1
if len(got) != len(input) {
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
}
// Test with normal p value
got = topP(tokens, 0.95)
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
t.Logf("got: %v", got)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e6} // One dominant token
input = []float32{-1e6, -1e6, -1e7}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
tokens = topP(tokens, 0.0) // Very small p
if len(tokens) < 1 {
got = topP(tokens, 0.0)
if len(got) < 1 {
t.Error("topP should keep at least one token")
}
// Test with zero p value
got = topP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != 1 {
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 1e-10)
if len(got) == 0 {
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
t.Logf("got: %v", got)
}
}
func TestMinP(t *testing.T) {
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
tokens := toTokens(input)
// First apply temperature and softmax
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
t.Logf("got: %v", tokens)
}
// Test with single token
tokens = toTokens(input[:1])
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.1)
// Should keep only the highest probability token
if len(tokens) != 1 {
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
}
}
got := minP(tokens, 1.0)
func TestSortLogits(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
if len(got) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
}
tokens = topK(tokens, 20)
// Test with normal p value
got = minP(tokens, 0.2)
for i := 1; i < len(tokens); i++ {
if tokens[i].value > tokens[i-1].value {
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
i, tokens[i].value, tokens[i-1].value)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
// Test with zero p value
got = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != len(tokens) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
}
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
compareLogits(t, "sortLogits", want, tokens)
}
func BenchmarkTransforms(b *testing.B) {

View File

@@ -37,7 +37,6 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names"
_ "embed"
@@ -213,12 +212,6 @@ type Registry struct {
// request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
Mask string
@@ -447,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
if err != nil {
return err
}
// TODO(bmizerany): decide if this should be considered valid. Maybe
// server-side we special case '{}' to have some special meaning? Maybe
// "archiving" a tag (which is how we reason about it in the registry
// already, just with a different twist).
if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
@@ -456,11 +454,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err
}
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
}
// TODO(bmizerany): work to remove the need to do this
layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config)
@@ -469,19 +463,16 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts.
t := traceFromContext(ctx)
skip := make([]bool, len(layers))
for i, l := range layers {
for _, l := range layers {
t.update(l, 0, nil)
if exists(l) {
skip[i] = true
t.update(l, l.Size, ErrCached)
}
}
g, ctx := errgroup.WithContext(ctx)
var g errgroup.Group
g.SetLimit(r.maxStreams())
for i, l := range layers {
if skip[i] {
for _, l := range layers {
info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
t.update(l, l.Size, ErrCached)
continue
}
@@ -490,63 +481,50 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
t.update(l, 0, err)
continue
}
// TODO(bmizerany): fix this unbounded use of defer
defer chunked.Close()
var progress atomic.Int64
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Bad chunksums response, update tracing
// clients and then bail.
t.update(l, progress.Load(), err)
break
return err
}
g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
defer func() {
if err != nil {
return err
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
}
err := func() error {
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
t.update(l, progress.Load(), err)
}()
// Count bytes towards
// progress, as they arrive, so
// that our bytes piggyback
// other chunk updates on
// completion.
//
// This tactic is enough to
// show "smooth" progress given
// the current CLI client. In
// the near future, the server
// should report download rate
// since it knows better than
// a client that is measuring
// rate based on wall-clock
// time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
err = chunked.Put(cs.Chunk, cs.Digest, body)
if err != nil {
return err
}
return nil
}()
if !canRetry(err) {
return err
}
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
return nil
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
// Count bytes towards progress, as they
// arrive, so that our bytes piggyback other
// chunk updates on completion.
//
// This tactic is enough to show "smooth"
// progress given the current CLI client. In
// the near future, the server should report
// download rate since it knows better than a
// client that is measuring rate based on
// wall-clock time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
}
}
@@ -554,13 +532,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err
}
// store the manifest blob
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
return err
}
// commit the manifest with a link
return c.Link(m.Name, md)
}
@@ -782,12 +757,15 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
}
blobURL := res.Header.Get("Content-Location")
var size int64
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
} else if size != l.Size {
yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
}
return
}
@@ -811,6 +789,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
return
}
size += chunk.Size()
if size > l.Size {
yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,

View File

@@ -17,6 +17,7 @@ import (
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
@@ -56,21 +57,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
@@ -88,7 +89,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
r := &Registry{
Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
Transport: recordRoundTripper(upstreamRegistry),
},
}
@@ -767,3 +768,74 @@ func TestUnlink(t *testing.T) {
}
})
}
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
if !slices.Equal(reads, []int64{0, 3, 5}) {
t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}

View File

@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err)
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {

View File

@@ -0,0 +1,13 @@
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{- if and (eq $i 1) $.System }}
{{ $.System }}
{{ end }}
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}

View File

@@ -0,0 +1,6 @@
{
"stop": [
"<end_of_turn>"
],
"temperature": 0.1
}

View File

@@ -87,6 +87,10 @@
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"name": "gemma-instruct"
},
{
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
"name": "gemma3-instruct"
},
{
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"name": "llama3-instruct"

View File

@@ -0,0 +1,10 @@
<start_of_turn>user
You are a helpful assistant.
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,4 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,8 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2021 Tristan Rice
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,57 +0,0 @@
// Vendored code from https://github.com/d4l3k/go-bfloat16
// unsafe pointer replaced by "math"
package bfloat16
import "math"
type BF16 uint16
func FromBytes(buf []byte) BF16 {
return BF16(uint16(buf[0]) + uint16(buf[1])<<8)
}
func ToBytes(b BF16) []byte {
return []byte{byte(b & 0xFF), byte(b >> 8)}
}
func Decode(buf []byte) []BF16 {
var out []BF16
for i := 0; i < len(buf); i += 2 {
out = append(out, FromBytes(buf[i:]))
}
return out
}
func Encode(f []BF16) []byte {
var out []byte
for _, a := range f {
out = append(out, ToBytes(a)...)
}
return out
}
func DecodeFloat32(buf []byte) []float32 {
var out []float32
for i := 0; i < len(buf); i += 2 {
out = append(out, ToFloat32(FromBytes(buf[i:])))
}
return out
}
func EncodeFloat32(f []float32) []byte {
var out []byte
for _, a := range f {
out = append(out, ToBytes(FromFloat32(a))...)
}
return out
}
func ToFloat32(b BF16) float32 {
u32 := uint32(b) << 16
return math.Float32frombits(u32)
}
func FromFloat32(f float32) BF16 {
u32 := math.Float32bits(f)
return BF16(u32 >> 16)
}

View File

@@ -1,53 +0,0 @@
package bfloat16
import (
"crypto/rand"
"reflect"
"testing"
)
func randomBytes(n int) []byte {
out := make([]byte, n)
if _, err := rand.Read(out); err != nil {
panic(err)
}
return out
}
func TestEncodeDecode(t *testing.T) {
b := randomBytes(1024)
bf16 := Decode(b)
out := Encode(bf16)
if !reflect.DeepEqual(b, out) {
t.Fatalf("%+v != %+v", b, out)
}
}
func TestEncodeDecodeFloat32(t *testing.T) {
b := randomBytes(1024)
bf16 := DecodeFloat32(b)
out := EncodeFloat32(bf16)
if !reflect.DeepEqual(b, out) {
t.Fatalf("%+v != %+v", b, out)
}
}
func TestBasicFloat32(t *testing.T) {
var in float32 = 1.0
out := ToFloat32(FromFloat32(in))
if !reflect.DeepEqual(in, out) {
t.Fatalf("%+v != %+v", in, out)
}
}
func TestComplexFloat32(t *testing.T) {
var in float32 = 123456789123456789.123456789
var want float32 = 123286039799267328.0
out := ToFloat32(FromFloat32(in))
if in == out {
t.Fatalf("no loss of precision")
}
if out != want {
t.Fatalf("%.16f != %.16f", want, out)
}
}