staticcheck
This commit is contained in:
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.attnKeyLen, m.ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -178,10 +178,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
if len(m.Layers) == gemma27BLayerCount {
|
||||
m.Options.largeModelScaling = true
|
||||
m.largeModelScaling = true
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
@@ -202,9 +202,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenState = m.Output.Forward(ctx, hiddenState)
|
||||
|
||||
// final logit softcap
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
|
||||
return hiddenState.Scale(ctx, float64(m.finalLogitSoftcap)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -96,15 +96,15 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
m.imageSize,
|
||||
m.imageSize,
|
||||
m.numChannels,
|
||||
)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
|
||||
@@ -111,12 +111,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase := m.TextConfig.ropeLocalBase
|
||||
ropeBase := m.ropeLocalBase
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = m.TextConfig.ropeGlobalBase
|
||||
ropeBase = m.ropeGlobalBase
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.attnKeyLen, ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
@@ -166,7 +166,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
var except []int
|
||||
|
||||
@@ -53,7 +53,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
MultiModalProjector: newMultiModalProjector(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -109,12 +109,12 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, size, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, size, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
|
||||
pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.numChannels)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
|
||||
|
||||
@@ -133,7 +133,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
|
||||
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
|
||||
// Prepare position IDs for 2D rope
|
||||
positions := make([]int32, numPatches)
|
||||
|
||||
@@ -54,7 +54,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
encoderCache := kvcache.NewEncoderCache()
|
||||
encoderCache.SetConfig(ml.CacheConfig{})
|
||||
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
||||
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.Shift))
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, ratio, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, ratio, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -59,14 +59,13 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, grid, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||
patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
||||
|
||||
@@ -228,7 +228,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||
|
||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
|
||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.numHeads)
|
||||
// Apply encoder layers
|
||||
for i, layer := range m.Layers {
|
||||
if slices.Contains(m.fullAttnBlocks, int32(i)) {
|
||||
|
||||
@@ -203,7 +203,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
Reference in New Issue
Block a user