refactor rope
change to a flatter directory structure and group the options with the function update models to call rope in one place
This commit is contained in:
committed by
Michael Yang
parent
e082d60a24
commit
603ceefaa6
@@ -23,18 +23,18 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
||||
}
|
||||
|
||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
|
||||
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
|
||||
|
||||
Reference in New Issue
Block a user