package gguf import ( "bytes" "cmp" "encoding/binary" "errors" "fmt" "io" "iter" "os" "slices" "strings" ) const ( typeUint8 uint32 = iota typeInt8 typeUint16 typeInt16 typeUint32 typeInt32 typeFloat32 typeBool typeString typeArray typeUint64 typeInt64 typeFloat64 ) var ErrUnsupported = errors.New("unsupported") type File struct { Magic [4]byte Version uint32 keyValues *lazy[KeyValue] tensorInfos *lazy[TensorInfo] offset int64 n uint64 file *os.File reader *bufferedReader bts []byte } func Open(path string) (f *File, err error) { f = &File{bts: make([]byte, 4096)} f.file, err = os.Open(path) if err != nil { return nil, err } f.reader = newBufferedReader(f.file, 32<<10) if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil { return nil, err } if bytes.Equal(f.Magic[:], []byte("gguf")) { return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic) } if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil { return nil, err } if f.Version < 2 { return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version) } f.tensorInfos, err = newLazy(f, f.readTensor) if err != nil { return nil, err } f.tensorInfos.successFunc = func() error { offset := f.reader.offset alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32) f.offset = offset + (alignment-offset%alignment)%alignment return nil } f.keyValues, err = newLazy(f, f.readKeyValue) if err != nil { return nil, err } return f, nil } func (f *File) readTensor() (TensorInfo, error) { name, err := readString(f) if err != nil { return TensorInfo{}, err } dims, err := read[uint32](f) if err != nil { return TensorInfo{}, err } shape := make([]uint64, dims) for i := range dims { shape[i], err = read[uint64](f) if err != nil { return TensorInfo{}, err } } type_, err := read[uint32](f) if err != nil { return TensorInfo{}, err } offset, err := read[uint64](f) if err != nil { return TensorInfo{}, err } tensorInfo := TensorInfo{ Name: name, Offset: offset, Shape: shape, Type: TensorType(type_), } f.n += tensorInfo.NumValues() return tensorInfo, nil } func (f *File) readKeyValue() (KeyValue, error) { key, err := readString(f) if err != nil { return KeyValue{}, err } t, err := read[uint32](f) if err != nil { return KeyValue{}, err } value, err := func() (any, error) { switch t { case typeUint8: return read[uint8](f) case typeInt8: return read[int8](f) case typeUint16: return read[uint16](f) case typeInt16: return read[int16](f) case typeUint32: return read[uint32](f) case typeInt32: return read[int32](f) case typeUint64: return read[uint64](f) case typeInt64: return read[int64](f) case typeFloat32: return read[float32](f) case typeFloat64: return read[float64](f) case typeBool: return read[bool](f) case typeString: return readString(f) case typeArray: return readArray(f) default: return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) } }() if err != nil { return KeyValue{}, err } return KeyValue{ Key: key, Value: Value{value}, }, nil } func read[T any](f *File) (t T, err error) { err = binary.Read(f.reader, binary.LittleEndian, &t) return t, err } func readString(f *File) (string, error) { bts := f.bts[:8] if _, err := io.ReadFull(f.reader, bts); err != nil { return "", err } n := binary.LittleEndian.Uint64(bts) if int(n) > len(f.bts) { f.bts = make([]byte, n) } bts = f.bts[:n] if _, err := io.ReadFull(f.reader, bts); err != nil { return "", err } return string(bts), nil } func readArray(f *File) (any, error) { t, err := read[uint32](f) if err != nil { return nil, err } n, err := read[uint64](f) if err != nil { return nil, err } switch t { case typeUint8: return readArrayData[uint8](f, n) case typeInt8: return readArrayData[int8](f, n) case typeUint16: return readArrayData[uint16](f, n) case typeInt16: return readArrayData[int16](f, n) case typeUint32: return readArrayData[uint32](f, n) case typeInt32: return readArrayData[int32](f, n) case typeUint64: return readArrayData[uint64](f, n) case typeInt64: return readArrayData[int64](f, n) case typeFloat32: return readArrayData[float32](f, n) case typeFloat64: return readArrayData[float64](f, n) case typeBool: return readArrayData[bool](f, n) case typeString: return readArrayString(f, n) default: return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) } } func readArrayData[T any](f *File, n uint64) (*lazy[T], error) { offset := f.reader.offset var t T if _, err := f.reader.Discard(int(n) * binary.Size(t)); err != nil { return nil, err } sr := io.NewSectionReader(f.file, offset, int64(int(n)*binary.Size(t))) next, stop := iter.Pull(func(yield func(T) bool) { s := make([]T, n) if err := binary.Read(sr, binary.LittleEndian, &s); err != nil { return } for _, e := range s { if !yield(e) { return } } }) return &lazy[T]{count: n, next: next, stop: stop}, nil } func readArrayString(f *File, n uint64) (*lazy[string], error) { offset := f.reader.offset var size int64 for range n { bts := f.bts[:8] if _, err := io.ReadFull(f.reader, bts); err != nil { return nil, err } n := int(binary.LittleEndian.Uint64(bts)) if _, err := f.reader.Discard(n); err != nil { return nil, err } size += 8 + int64(n) } sr := io.NewSectionReader(f.file, offset, size) next, stop := iter.Pull(func(yield func(string) bool) { f := File{reader: newBufferedReader(sr, 16<<10), bts: make([]byte, 4096)} for range n { s, err := readString(&f) if err != nil { return } if !yield(s) { return } } }) return &lazy[string]{count: n, next: next, stop: stop}, nil } func (f *File) Close() error { f.keyValues.stop() f.tensorInfos.stop() return f.file.Close() } func (f *File) KeyValue(key string) KeyValue { if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") { key = f.KeyValue("general.architecture").String() + "." + key } if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool { return kv.Key == key }); index >= 0 { return f.keyValues.values[index] } for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() { if keyValue.Key == key { return keyValue } } return KeyValue{} } func (f *File) NumKeyValues() int { return int(f.keyValues.count) } func (f *File) KeyValues() iter.Seq2[int, KeyValue] { return f.keyValues.All() } func (f *File) TensorInfo(name string) TensorInfo { if index := slices.IndexFunc(f.tensorInfos.values, func(t TensorInfo) bool { return t.Name == name }); index >= 0 { return f.tensorInfos.values[index] } // fast-forward through key values if we haven't already _ = f.keyValues.rest() for tensor, ok := f.tensorInfos.next(); ok; tensor, ok = f.tensorInfos.next() { if tensor.Name == name { return tensor } } return TensorInfo{} } func (f *File) NumTensors() int { return int(f.tensorInfos.count) } func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] { // fast forward through key values if we haven't already _ = f.keyValues.rest() return f.tensorInfos.All() } func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) { t := f.TensorInfo(name) if t.NumBytes() == 0 { return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name) } // fast forward through tensor info if we haven't already _ = f.tensorInfos.rest() return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), int64(t.NumBytes())), nil } func (f *File) NumValues() uint64 { _ = f.tensorInfos.rest() return f.n }