diff --git a/model/model.go b/model/model.go index 5493a4e63..f3d6bb3db 100644 --- a/model/model.go +++ b/model/model.go @@ -107,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { return nil, err } - arch := b.Config().Architecture() - if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone { - arch = arch + "_embed" - } - - f, ok := models[arch] - if !ok { - return nil, fmt.Errorf("unsupported model architecture %q", arch) - } - - m, err := f(b.Config()) + m, err := modelForArch(b.Config()) if err != nil { return nil, err } base := Base{b: b, config: m.Config()} - v := reflect.ValueOf(m) v.Elem().Set(populateFields(base, v.Elem())) return m, nil @@ -135,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) { return nil, err } defer r.Close() + meta, err := fsggml.Decode(r, -1) if err != nil { return nil, err } - return getTextProcessor(meta.KV()) -} -func getTextProcessor(kv fsggml.KV) (TextProcessor, error) { - arch := kv.Architecture() - f, ok := models[arch] - if !ok { - return nil, fmt.Errorf("unsupported model architecture %q", arch) - } - m, err := f(kv) + m, err := modelForArch(meta.KV()) if err != nil { return nil, err } + tp, ok := m.(TextProcessor) if !ok { - return nil, fmt.Errorf("%v is not a TextProcessor", m) + return nil, ErrUnsupportedTokenizer } return tp, nil } +func modelForArch(c fs.Config) (Model, error) { + arch := c.Architecture() + if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone { + arch = arch + "_embed" + } + + f, ok := models[arch] + if !ok { + return nil, ErrUnsupportedModel + } + + return f(c) +} + func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() diff --git a/model/model_test.go b/model/model_test.go index 020f9ffbd..01080ffdf 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -1,9 +1,9 @@ package model import ( + "errors" "reflect" "slices" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -12,7 +12,6 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/model/input" ) func TestParseTags(t *testing.T) { @@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) { } } -func TestGetTextProcessor(t *testing.T) { - tp, err := getTextProcessor(fsggml.KV{}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "unsupported model architecture") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") +func TestModelForArch(t *testing.T) { + type fakeModel struct { + Model } - models["dummy"] = func(fs.Config) (Model, error) { - return notTextProcessorModel{}, nil + type fakeEmbeddingModel struct { + Model } - tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "not a TextProcessor") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") + + models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil } + models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil } + + cases := []struct { + name string + config fs.Config + want any + err error + }{ + { + name: "model", + config: fsggml.KV{ + "general.architecture": "model", + }, + want: fakeModel{}, + }, + { + name: "embedding", + config: fsggml.KV{ + "general.architecture": "model", + "model.pooling_type": uint32(1), + }, + want: fakeEmbeddingModel{}, + }, + { + name: "unsupported", + config: fsggml.KV{ + "general.architecture": "unsupported", + }, + err: ErrUnsupportedModel, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := modelForArch(tt.config) + if !errors.Is(err, tt.err) { + t.Fatal(err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff) + } + }) } } - -type notTextProcessorModel struct{} - -func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) { - panic("unimplemented") -} - -func (notTextProcessorModel) Backend() ml.Backend { - panic("unimplemented") -} - -func (notTextProcessorModel) Config() config { - panic("unimplemented") -}