prefer tokenizer.json
check tokenizer decoder for tokenizer type
parent 1c094038bc
commit bddcb3fb16
@@ -204,12 +204,14 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 type tokenizer struct {
     AddedTokens []token `json:"added_tokens"`
 
+    Decoder struct {
+        Type string `json:"type"`
+    } `json:"decoder"`
+
     Model struct {
         Type   string          `json:"type"`
         Vocab  map[string]int  `json:"vocab"`
         Merges json.RawMessage `json:"merges"`
     } `json:"model"`
 
     PreTokenizer struct {
         PreTokenizers []struct {
             Type string `json:"type"`
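For reference, a minimal standalone sketch (not part of this commit) of what the new Decoder field captures when a tokenizer.json is unmarshalled. The JSON fragment below is hypothetical:

    // sketch: unmarshal only the decoder section of a tokenizer.json
    package main

    import (
        "encoding/json"
        "fmt"
    )

    type tokenizer struct {
        Decoder struct {
            Type string `json:"type"`
        } `json:"decoder"`
    }

    func main() {
        // hypothetical fragment of a SentencePiece-style tokenizer.json
        data := []byte(`{"decoder": {"type": "SentencePiece"}}`)

        var t tokenizer
        if err := json.Unmarshal(data, &t); err != nil {
            panic(err)
        }
        fmt.Println(t.Decoder.Type) // SentencePiece
    }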
@@ -246,6 +248,11 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
         return nil, err
     }
 
+    model := "gpt2"
+    if t.Decoder.Type == "SentencePiece" {
+        model = "llama"
+    }
+
     tokens := make(map[int]token, len(t.Model.Vocab))
     for k, v := range t.Model.Vocab {
         tokens[v] = token{
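The effect of the added branch, pulled out into a hypothetical helper for clarity (modelForDecoder is not in the commit): a SentencePiece decoder selects the "llama" vocabulary model, and everything else keeps the "gpt2" (BPE) default.

    package convert

    // modelForDecoder mirrors the branch added above.
    func modelForDecoder(decoderType string) string {
        if decoderType == "SentencePiece" {
            return "llama"
        }
        return "gpt2" // BPE-style tokenizers stay on the old default
    }

So modelForDecoder("SentencePiece") yields "llama", while any other decoder type, or a missing decoder section, falls through to "gpt2".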
@@ -259,7 +266,7 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
         tokens[token.ID] = token
     }
 
-    v := Vocabulary{Model: "gpt2"}
+    v := Vocabulary{Model: model}
     for _, k := range slices.Sorted(maps.Keys(tokens)) {
         token := tokens[k]
         v.Tokens = append(v.Tokens, token.Content)
@@ -283,8 +290,8 @@ func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
         Pattern string
         Func    func(fs.FS) (*Vocabulary, error)
     }{
-        {"tokenizer.model", parseSentencePiece},
         {"tokenizer.json", parseVocabularyFromTokenizer},
+        {"tokenizer.model", parseSentencePiece},
     }
 
     for _, pattern := range patterns {
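The reordering matters because parseVocabulary tries the patterns in slice order, so tokenizer.json now wins whenever a model ships both files. Below is a sketch of that first-match behavior, assuming the loop (elided from this hunk) skips candidates that are missing from the filesystem; firstExisting is a hypothetical stand-in, not code from the commit:

    package convert

    import (
        "errors"
        "io/fs"
    )

    // firstExisting models the probing loop: the first candidate that
    // exists in fsys is selected, so earlier patterns shadow later ones.
    func firstExisting(fsys fs.FS, candidates []string) (string, error) {
        for _, name := range candidates {
            if _, err := fs.Stat(fsys, name); errors.Is(err, fs.ErrNotExist) {
                continue // try the next candidate
            } else if err != nil {
                return "", err
            }
            return name, nil // tokenizer.json is listed first, so it wins
        }
        return "", errors.New("no tokenizer file found")
    }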