diff --git a/convert/tokenizer.go b/convert/tokenizer.go
index 41d0310a0..ccacbda85 100644
--- a/convert/tokenizer.go
+++ b/convert/tokenizer.go
@@ -204,12 +204,14 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 type tokenizer struct {
 	AddedTokens []token `json:"added_tokens"`
 
-	Model struct {
+	Decoder struct {
+		Type string `json:"type"`
+	} `json:"decoder"`
+	Model struct {
 		Type   string          `json:"type"`
 		Vocab  map[string]int  `json:"vocab"`
 		Merges json.RawMessage `json:"merges"`
 	} `json:"model"`
-
 	PreTokenizer struct {
 		PreTokenizers []struct {
 			Type string `json:"type"`
@@ -246,6 +248,11 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
 		return nil, err
 	}
 
+	model := "gpt2"
+	if t.Decoder.Type == "SentencePiece" {
+		model = "llama"
+	}
+
 	tokens := make(map[int]token, len(t.Model.Vocab))
 	for k, v := range t.Model.Vocab {
 		tokens[v] = token{
@@ -259,7 +266,7 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
 		tokens[token.ID] = token
 	}
 
-	v := Vocabulary{Model: "gpt2"}
+	v := Vocabulary{Model: model}
 	for _, k := range slices.Sorted(maps.Keys(tokens)) {
 		token := tokens[k]
 		v.Tokens = append(v.Tokens, token.Content)
@@ -283,8 +290,8 @@ func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
 		Pattern string
 		Func    func(fs.FS) (*Vocabulary, error)
 	}{
-		{"tokenizer.model", parseSentencePiece},
 		{"tokenizer.json", parseVocabularyFromTokenizer},
+		{"tokenizer.model", parseSentencePiece},
 	}
 
 	for _, pattern := range patterns {
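
For context, the net effect of the patch can be sketched outside the diff: tokenizer.json is now tried before tokenizer.model, and the vocabulary model name is chosen from the decoder type rather than hardcoded to "gpt2". The sketch below is not part of the patch; the struct and function names (decoderOnly, vocabularyModel) are illustrative only, though the decoder-to-model mapping follows the added code.

// Sketch (not part of the patch): how the decoder type selects the
// vocabulary model when parsing tokenizer.json.
package main

import (
	"encoding/json"
	"fmt"
)

// decoderOnly models just the "decoder" block that the patch adds to the
// tokenizer struct.
type decoderOnly struct {
	Decoder struct {
		Type string `json:"type"`
	} `json:"decoder"`
}

// vocabularyModel mirrors the logic added to parseVocabularyFromTokenizer:
// a SentencePiece decoder selects the "llama" vocabulary model, anything
// else keeps the previous "gpt2" default.
func vocabularyModel(decoderType string) string {
	if decoderType == "SentencePiece" {
		return "llama"
	}
	return "gpt2"
}

func main() {
	// Hypothetical tokenizer.json fragment, used only for this example.
	raw := []byte(`{"decoder": {"type": "SentencePiece"}}`)

	var t decoderOnly
	if err := json.Unmarshal(raw, &t); err != nil {
		panic(err)
	}

	fmt.Println(vocabularyModel(t.Decoder.Type)) // prints "llama"
}

Reordering the patterns so tokenizer.json is matched first presumably means SentencePiece-style checkpoints that ship both tokenizer.json and tokenizer.model now have their vocabulary parsed from tokenizer.json, which is why the decoder check is needed to still report "llama" instead of "gpt2".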