Compare commits
3 Commits
grace/mist
...
v0.13.5-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7325791599 | ||
|
|
522c11a763 | ||
|
|
0fadeffaee |
@@ -216,8 +216,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &deepseekocr{}
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
case "MistralForCausalLM":
|
||||
conv = &mistralLarge3Model{}
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
@@ -1,286 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type mistralLarge3Model struct {
|
||||
ModelParameters
|
||||
Dim uint32 `json:"dim"`
|
||||
NumLayers uint32 `json:"n_layers"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenDim uint32 `json:"hidden_dim"`
|
||||
NumHeads uint32 `json:"n_heads"`
|
||||
NumKVHeads uint32 `json:"n_kv_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
NormEps float32 `json:"norm_eps"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
TiedEmbeddings bool `json:"tied_embeddings"`
|
||||
MaxPosEmbed uint32 `json:"max_position_embeddings"`
|
||||
MaxSeqLen uint32 `json:"max_seq_len"`
|
||||
|
||||
// LoRA attention parameters (DeepSeek-style)
|
||||
QLoraRank uint32 `json:"q_lora_rank"`
|
||||
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
|
||||
KVLoraRank uint32 `json:"kv_lora_rank"`
|
||||
VHeadDim uint32 `json:"v_head_dim"`
|
||||
|
||||
// ROPE scaling configurations
|
||||
Llama4Scaling struct {
|
||||
OrigMaxPosEmbed uint32 `json:"original_max_position_embeddings"`
|
||||
Beta float32 `json:"beta"`
|
||||
} `json:"llama_4_scaling"`
|
||||
|
||||
Yarn struct {
|
||||
OrigMaxPosEmbed uint32 `json:"original_max_position_embeddings"`
|
||||
Factor float32 `json:"factor"`
|
||||
ApplyScale bool `json:"apply_scale"`
|
||||
Beta float32 `json:"beta"`
|
||||
Alpha float32 `json:"alpha"`
|
||||
} `json:"yarn"`
|
||||
|
||||
// MOE configuration
|
||||
MOE struct {
|
||||
ExpertParallel uint32 `json:"expert_parallel"`
|
||||
ExpertModelParallel uint32 `json:"expert_model_parallel"`
|
||||
RouteEveryN uint32 `json:"route_every_n"`
|
||||
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
|
||||
NumExpertGroups uint32 `json:"num_expert_groups"`
|
||||
NumExpertGroupsPerTok uint32 `json:"num_expert_groups_per_tok"`
|
||||
RoutedScale float32 `json:"routed_scale"`
|
||||
ExpertHiddenDim uint32 `json:"expert_hidden_dim"`
|
||||
NumSharedExperts uint32 `json:"num_shared_experts"`
|
||||
} `json:"moe"`
|
||||
|
||||
// Vision encoder configuration
|
||||
VisionEncoder struct {
|
||||
ImageTokenID uint32 `json:"image_token_id"`
|
||||
ImageBreakTokenID uint32 `json:"image_break_token_id"`
|
||||
ImageEndTokenID uint32 `json:"image_end_token_id"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
MMProjectorID string `json:"mm_projector_id"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
MaxImageSize uint32 `json:"max_image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
AddPreMMProjectorLayerNorm bool `json:"add_pre_mm_projector_layer_norm"`
|
||||
AdapterBias bool `json:"adapter_bias"`
|
||||
} `json:"vision_encoder"`
|
||||
}
|
||||
|
||||
func (p *mistralLarge3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2" // Use deepseek2 architecture for runtime compatibility
|
||||
kv["general.type"] = "model"
|
||||
|
||||
// Basic model parameters (using deepseek2 keys for compatibility)
|
||||
kv["deepseek2.vocab_size"] = p.VocabSize
|
||||
kv["deepseek2.block_count"] = p.NumLayers
|
||||
kv["deepseek2.context_length"] = cmp.Or(p.MaxPosEmbed, p.MaxSeqLen)
|
||||
kv["deepseek2.embedding_length"] = p.Dim
|
||||
kv["deepseek2.feed_forward_length"] = p.HiddenDim
|
||||
|
||||
// Attention configuration
|
||||
kv["deepseek2.attention.head_count"] = p.NumHeads
|
||||
kv["deepseek2.attention.head_count_kv"] = p.NumKVHeads
|
||||
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.NormEps
|
||||
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
|
||||
kv["deepseek2.attention.value_length"] = p.VHeadDim
|
||||
|
||||
// LoRA attention parameters
|
||||
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
|
||||
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
|
||||
|
||||
// ROPE configuration
|
||||
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
|
||||
|
||||
// ROPE scaling - map to deepseek2 format
|
||||
if p.Yarn.OrigMaxPosEmbed > 0 {
|
||||
kv["deepseek2.rope.scaling.factor"] = p.Yarn.Factor
|
||||
kv["deepseek2.rope.scaling.original_context_length"] = p.Yarn.OrigMaxPosEmbed
|
||||
kv["deepseek2.rope.scaling.type"] = "yarn"
|
||||
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = float32(0.1) // mscale_all_dim * 0.1 as in llama.cpp
|
||||
}
|
||||
|
||||
// MOE configuration
|
||||
if p.MOE.NumExperts > 0 {
|
||||
kv["deepseek2.expert_count"] = p.MOE.NumExperts
|
||||
kv["deepseek2.expert_used_count"] = p.MOE.NumExpertsPerTok
|
||||
kv["deepseek2.expert_shared_count"] = p.MOE.NumSharedExperts
|
||||
kv["deepseek2.expert_feed_forward_length"] = p.MOE.ExpertHiddenDim
|
||||
kv["deepseek2.expert_weights_scale"] = p.MOE.RoutedScale
|
||||
kv["deepseek2.leading_dense_block_count"] = p.MOE.FirstKDenseReplace
|
||||
kv["deepseek2.expert_weights_norm"] = true
|
||||
kv["deepseek2.expert_gating_func"] = uint32(1) // softmax
|
||||
}
|
||||
|
||||
// Vision encoder configuration (if supported by deepseek2 runtime)
|
||||
if p.VisionEncoder.HiddenSize > 0 {
|
||||
kv["deepseek2.vision.block_count"] = p.VisionEncoder.NumHiddenLayers
|
||||
kv["deepseek2.vision.embedding_length"] = p.VisionEncoder.HiddenSize
|
||||
kv["deepseek2.vision.feed_forward_length"] = p.VisionEncoder.IntermediateSize
|
||||
kv["deepseek2.vision.attention.head_count"] = p.VisionEncoder.NumAttentionHeads
|
||||
kv["deepseek2.vision.image_size"] = p.VisionEncoder.ImageSize
|
||||
kv["deepseek2.vision.patch_size"] = p.VisionEncoder.PatchSize
|
||||
kv["deepseek2.vision.num_channels"] = p.VisionEncoder.NumChannels
|
||||
|
||||
// Multimodal configuration
|
||||
kv["deepseek2.image_token_id"] = p.VisionEncoder.ImageTokenID
|
||||
kv["deepseek2.image_break_token_id"] = p.VisionEncoder.ImageBreakTokenID
|
||||
kv["deepseek2.image_end_token_id"] = p.VisionEncoder.ImageEndTokenID
|
||||
kv["deepseek2.spatial_merge_size"] = p.VisionEncoder.SpatialMergeSize
|
||||
}
|
||||
|
||||
// Set tokenizer type - use tekken preprocessing (now supported!)
|
||||
kv["tokenizer.ggml.pre"] = "tekken"
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mistralLarge3Model) specialTokenTypes() []string {
|
||||
return []string{
|
||||
"bos", "eos", "unk", "sep", "pad", "cls", "mask",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mistralLarge3Model) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"tok_embeddings", "token_embd", // Mistral Large uses tok_embeddings instead of model.embed_tokens
|
||||
"norm", "output_norm",
|
||||
"language_model.", "",
|
||||
"layers", "blk", // Mistral 3 Large uses "layers" instead of "model.layers"
|
||||
"attention_norm", "attn_norm",
|
||||
|
||||
// LoRA attention mappings (Mistral 3 Large style)
|
||||
"attention.wkv_a_with_mqa", "attn_kv_a_mqa",
|
||||
"attention.kv_a_norm", "attn_kv_a_norm",
|
||||
"attention.wkv_b", "attn_kv_b",
|
||||
"attention.wq_a", "attn_q_a",
|
||||
"attention.q_a_norm", "attn_q_a_norm",
|
||||
"attention.wq_b", "attn_q_b",
|
||||
"attention.wo", "attn_output",
|
||||
|
||||
"ffn_norm", "ffn_norm", // Keep ffn_norm as is
|
||||
|
||||
// MOE mappings for Mistral 3 Large
|
||||
"shared_experts.w2", "ffn_down_shexp",
|
||||
"shared_experts.w1", "ffn_gate_shexp",
|
||||
"shared_experts.w3", "ffn_up_shexp",
|
||||
"experts.*.w1", "ffn_gate_exps", // Will be merged in Tensors()
|
||||
"experts.*.w2", "ffn_down_exps", // Will be merged in Tensors()
|
||||
"experts.*.w3", "ffn_up_exps", // Will be merged in Tensors()
|
||||
"gate", "ffn_gate_inp",
|
||||
|
||||
// Standard feed forward mappings (for non-MOE layers)
|
||||
"feed_forward.w1", "ffn_gate",
|
||||
"feed_forward.w2", "ffn_down",
|
||||
"feed_forward.w3", "ffn_up",
|
||||
|
||||
// Mistral-specific tensor renaming
|
||||
".qscale_act", ".input_scale",
|
||||
".qscale_weight", ".weight_scale",
|
||||
|
||||
// Vision encoder mappings - do we even need this?
|
||||
"vision_tower", "v",
|
||||
"ln_pre", "encoder_norm",
|
||||
"attention.q_proj", "attn_q",
|
||||
"attention.k_proj", "attn_k",
|
||||
"attention.v_proj", "attn_v",
|
||||
"attention.o_proj", "attn_output",
|
||||
"attention_norm", "attn_norm",
|
||||
"feed_forward.gate_proj", "ffn_gate",
|
||||
"feed_forward.down_proj", "ffn_down",
|
||||
"feed_forward.up_proj", "ffn_up",
|
||||
|
||||
"multi_modal_projector", "mm",
|
||||
"patch_merger.merging_layer", "mm.patch_merger",
|
||||
"pre_mm_projector_norm", "mm.pre_norm",
|
||||
"vision_language_adapter.w_in", "mm.w_in",
|
||||
"vision_language_adapter.w_out", "mm.w_out",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mistralLarge3Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
// Create merges for MOE expert tensors
|
||||
if p.MOE.NumExperts > 0 {
|
||||
merges := make([]merge, p.NumLayers*3)
|
||||
for i := range p.NumLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.experts.*.w1.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.experts.*.w3.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.experts.*.w2.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
out, s = mergeTensors(s, merges...)
|
||||
}
|
||||
|
||||
skipLayer := func(n string, minValue uint32) bool {
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(n)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return uint32(blkNum) >= minValue
|
||||
}
|
||||
|
||||
// Function to check if tensor should be skipped (vision components)
|
||||
skipVisionTensor := func(name string) bool {
|
||||
return strings.HasPrefix(name, "vision_") ||
|
||||
strings.HasPrefix(name, "patch_merger.") ||
|
||||
strings.Contains(name, "mm_projector")
|
||||
}
|
||||
|
||||
for _, t := range s {
|
||||
name := t.Name()
|
||||
|
||||
// Skip vision tensors (handled separately or not needed)
|
||||
if skipVisionTensor(name) {
|
||||
slog.Debug("skipping vision tensor", "name", name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip any additional layers beyond expected count
|
||||
if skipLayer(name, p.NumLayers) {
|
||||
slog.Debug("skipping extra layer", "name", name)
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -101,8 +101,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
t.Pre = "deepseek-coder"
|
||||
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
||||
t.Pre = "qwen2"
|
||||
case "1d64a9a8eaf9f1bd80331984d81fdd514e7feafe8df83a525dd31472f275699a":
|
||||
t.Pre = "tekken"
|
||||
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
||||
// noop, empty pretokenizer
|
||||
default:
|
||||
|
||||
@@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
|
||||
// temporary fix to handle gemma3 broken configs
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
|
||||
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
|
||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ package deepseek2
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -40,10 +39,6 @@ type Options struct {
|
||||
ropeBase,
|
||||
ropeScale float32
|
||||
kqScale float64
|
||||
|
||||
attentionTemperatureScale float32
|
||||
attentionTemperatureLength int
|
||||
attentionTemperatureFloorScale int
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
||||
@@ -71,7 +66,7 @@ type Attention struct {
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
|
||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
seqLength := hiddenStates.Dim(1)
|
||||
|
||||
var query ml.Tensor
|
||||
@@ -109,11 +104,6 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentio
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
|
||||
if attentionScales != nil {
|
||||
query = query.Mul(ctx, attentionScales)
|
||||
}
|
||||
|
||||
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
} else { // v3.1
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
@@ -125,10 +115,6 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentio
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
value := kPass
|
||||
|
||||
if attentionScales != nil {
|
||||
query = query.Mul(ctx, attentionScales)
|
||||
}
|
||||
|
||||
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
}
|
||||
|
||||
@@ -215,10 +201,10 @@ type Layer struct {
|
||||
MLP MLP
|
||||
}
|
||||
|
||||
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, opts)
|
||||
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
@@ -248,11 +234,7 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
// layers := make([]Layer, c.Uint("block_count"))
|
||||
// fmt.Printf("[MODEL DEBUG] Creating model with %d layers\n", c.Uint("block_count"))
|
||||
|
||||
layers := make([]Layer, 4)
|
||||
fmt.Printf("[MODEL DEBUG] Creating model with %d layers\n", 4)
|
||||
layers := make([]Layer, c.Uint("block_count"))
|
||||
|
||||
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
|
||||
for i := range layers {
|
||||
@@ -279,10 +261,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
case "tekken":
|
||||
pre = []string{
|
||||
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
case "deepseek-llm":
|
||||
// TODO: these models haven't been vetted so skip for now
|
||||
// pre = []string{
|
||||
@@ -298,20 +276,13 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
// DEBUG: Check tokenizer vocabulary loading
|
||||
tokens := c.Strings("tokenizer.ggml.tokens")
|
||||
tokenTypes := c.Ints("tokenizer.ggml.token_type")
|
||||
merges := c.Strings("tokenizer.ggml.merges")
|
||||
|
||||
// Debug output removed for performance
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: tokenTypes,
|
||||
Merges: merges,
|
||||
AddBOS: false, // c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
@@ -345,11 +316,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
routedScalingFactor: c.Float("expert_weights_scale"),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
|
||||
// TODO: double check these values
|
||||
attentionTemperatureScale: c.Float("attention.temperature_scale", 1.0),
|
||||
attentionTemperatureLength: int(c.Uint("attention.temperature_length")),
|
||||
attentionTemperatureFloorScale: int(c.Uint("attention.temperature_floor_scale", 8192)),
|
||||
|
||||
kqScale: kqScale,
|
||||
},
|
||||
}
|
||||
@@ -365,28 +331,8 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
// DEBUG: Check TokenEmbedding initialization
|
||||
if m.TokenEmbedding == nil {
|
||||
panic("DEBUG: m.TokenEmbedding is nil - 'token_embd' tensor not found in GGUF")
|
||||
}
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
// Temperature tuning - used by mistral-large
|
||||
var attentionScales ml.Tensor
|
||||
if m.attentionTemperatureScale != 0.0 {
|
||||
nTokens := len(batch.Positions)
|
||||
scales := make([]float32, nTokens)
|
||||
|
||||
for i, pos := range batch.Positions {
|
||||
posFloat := float64(pos)
|
||||
scaleValue := math.Log(math.Floor((posFloat+1.0)/float64(m.attentionTemperatureFloorScale))+1.0)*float64(m.attentionTemperatureScale) + 1.0
|
||||
scales[i] = float32(scaleValue)
|
||||
}
|
||||
|
||||
attentionScales = ctx.Input().FromFloats(scales, 1, 1, nTokens)
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
@@ -395,7 +341,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, outputs, m.Cache, m.Options)
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
|
||||
323
model/parsers/functiongemma.go
Normal file
323
model/parsers/functiongemma.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type FunctionGemmaParserState int
|
||||
|
||||
const (
|
||||
FunctionGemmaCollectingContent FunctionGemmaParserState = iota
|
||||
FunctionGemmaCollectingToolCalls
|
||||
)
|
||||
|
||||
const (
|
||||
functionGemmaFunctionCallOpen = "<start_function_call>"
|
||||
functionGemmaFunctionCallClose = "<end_function_call>"
|
||||
)
|
||||
|
||||
// This format uses <start_function_call>call:name{args}<end_function_call> for tool calls.
|
||||
type FunctionGemmaParser struct {
|
||||
state FunctionGemmaParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *FunctionGemmaParser) HasToolSupport() bool { return true }
|
||||
func (p *FunctionGemmaParser) HasThinkingSupport() bool { return false }
|
||||
|
||||
func (p *FunctionGemmaParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.state = FunctionGemmaCollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
type functionGemmaEvent interface {
|
||||
isFunctionGemmaEvent()
|
||||
}
|
||||
|
||||
type FunctionGemmaEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type functionGemmaEventToolCall struct {
|
||||
toolCall api.ToolCall
|
||||
}
|
||||
|
||||
func (FunctionGemmaEventContent) isFunctionGemmaEvent() {}
|
||||
func (functionGemmaEventToolCall) isFunctionGemmaEvent() {}
|
||||
|
||||
func (p *FunctionGemmaParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case functionGemmaEventToolCall:
|
||||
toolCalls = append(toolCalls, event.toolCall)
|
||||
case FunctionGemmaEventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), "", toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *FunctionGemmaParser) parseEvents() []functionGemmaEvent {
|
||||
var all []functionGemmaEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []functionGemmaEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// emitWithPartialCheck extracts unambiguous content before a potential partial tag
|
||||
func (p *FunctionGemmaParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
|
||||
if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
|
||||
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||
return beforePartialTag, bufStr[len(beforePartialTag):]
|
||||
}
|
||||
return bufStr, ""
|
||||
}
|
||||
|
||||
func (p *FunctionGemmaParser) eat() ([]functionGemmaEvent, bool) {
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case FunctionGemmaCollectingContent:
|
||||
if strings.Contains(bufStr, functionGemmaFunctionCallOpen) {
|
||||
split := strings.SplitN(bufStr, functionGemmaFunctionCallOpen, 2)
|
||||
content := split[0]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(split[1])
|
||||
p.state = FunctionGemmaCollectingToolCalls
|
||||
if content != "" {
|
||||
return []functionGemmaEvent{FunctionGemmaEventContent{content: content}}, true
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
unambig, ambig := p.emitWithPartialCheck(bufStr, functionGemmaFunctionCallOpen)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambig)
|
||||
if unambig != "" {
|
||||
return []functionGemmaEvent{FunctionGemmaEventContent{content: unambig}}, false
|
||||
}
|
||||
return nil, false
|
||||
|
||||
case FunctionGemmaCollectingToolCalls:
|
||||
if strings.Contains(bufStr, functionGemmaFunctionCallClose) {
|
||||
split := strings.SplitN(bufStr, functionGemmaFunctionCallClose, 2)
|
||||
remaining := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
|
||||
var events []functionGemmaEvent
|
||||
if tc, err := p.parseToolCall(split[0]); err == nil {
|
||||
events = append(events, functionGemmaEventToolCall{toolCall: tc})
|
||||
}
|
||||
|
||||
if !strings.Contains(remaining, functionGemmaFunctionCallOpen) {
|
||||
p.state = FunctionGemmaCollectingContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Matches call:function_name{args}
|
||||
var functionGemmaCallRegex = regexp.MustCompile(`call:([^{]+)\{(.*)\}`)
|
||||
|
||||
func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error) {
|
||||
toolCall := api.ToolCall{}
|
||||
|
||||
// Extract function name and arguments
|
||||
match := functionGemmaCallRegex.FindStringSubmatch(content)
|
||||
if len(match) < 3 {
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
toolCall.Function.Name = match[1]
|
||||
argsStr := match[2]
|
||||
|
||||
// Parse arguments
|
||||
toolCall.Function.Arguments = p.parseArguments(argsStr)
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
// parseArguments parses the key:value,key:value format
|
||||
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
|
||||
args := make(api.ToolCallFunctionArguments)
|
||||
if argsStr == "" {
|
||||
return args
|
||||
}
|
||||
|
||||
// Split by comma, but handle nested structures
|
||||
parts := p.splitArguments(argsStr)
|
||||
|
||||
for _, part := range parts {
|
||||
// Find the first colon to split key:value
|
||||
colonIdx := strings.Index(part, ":")
|
||||
if colonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := part[:colonIdx]
|
||||
value := part[colonIdx+1:]
|
||||
|
||||
// Parse the value
|
||||
args[key] = p.parseValue(value)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// splitArguments splits arguments by comma, respecting nested structures
|
||||
func (p *FunctionGemmaParser) splitArguments(argsStr string) []string {
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
depth := 0
|
||||
inEscape := false
|
||||
|
||||
for i := 0; i < len(argsStr); i++ {
|
||||
ch := argsStr[i]
|
||||
|
||||
// Check for <escape> tags
|
||||
if i+8 <= len(argsStr) && argsStr[i:i+8] == "<escape>" {
|
||||
inEscape = !inEscape
|
||||
current.WriteString("<escape>")
|
||||
i += 7 // Skip the rest of <escape>
|
||||
continue
|
||||
}
|
||||
|
||||
if !inEscape {
|
||||
switch ch {
|
||||
case '{', '[':
|
||||
depth++
|
||||
current.WriteByte(ch)
|
||||
case '}', ']':
|
||||
depth--
|
||||
current.WriteByte(ch)
|
||||
case ',':
|
||||
if depth == 0 {
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
continue
|
||||
}
|
||||
current.WriteByte(ch)
|
||||
default:
|
||||
current.WriteByte(ch)
|
||||
}
|
||||
} else {
|
||||
current.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// parseValue parses a single value from the FunctionGemma format
|
||||
func (p *FunctionGemmaParser) parseValue(value string) any {
|
||||
// Check for escaped string
|
||||
if strings.HasPrefix(value, "<escape>") && strings.HasSuffix(value, "<escape>") {
|
||||
// Remove the escape tags
|
||||
return value[8 : len(value)-8]
|
||||
}
|
||||
|
||||
// Check for boolean
|
||||
if value == "true" {
|
||||
return true
|
||||
}
|
||||
if value == "false" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for number
|
||||
if num, ok := parseNumber(value); ok {
|
||||
return num
|
||||
}
|
||||
|
||||
// Check for array
|
||||
if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
|
||||
return p.parseArray(value[1 : len(value)-1])
|
||||
}
|
||||
|
||||
// Check for object
|
||||
if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
|
||||
return p.parseObject(value[1 : len(value)-1])
|
||||
}
|
||||
|
||||
// Default to string
|
||||
return value
|
||||
}
|
||||
|
||||
// parseArray parses an array value
|
||||
func (p *FunctionGemmaParser) parseArray(content string) []any {
|
||||
var result []any
|
||||
parts := p.splitArguments(content)
|
||||
for _, part := range parts {
|
||||
result = append(result, p.parseValue(part))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseObject parses an object value
|
||||
func (p *FunctionGemmaParser) parseObject(content string) map[string]any {
|
||||
result := make(map[string]any)
|
||||
parts := p.splitArguments(content)
|
||||
for _, part := range parts {
|
||||
colonIdx := strings.Index(part, ":")
|
||||
if colonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
key := part[:colonIdx]
|
||||
value := part[colonIdx+1:]
|
||||
result[key] = p.parseValue(value)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseNumber tries to parse a string as a number
|
||||
func parseNumber(s string) (any, bool) {
|
||||
// Try integer first
|
||||
var intVal int64
|
||||
if _, err := fmt.Sscanf(s, "%d", &intVal); err == nil {
|
||||
// Check if the entire string was consumed
|
||||
if fmt.Sprintf("%d", intVal) == s {
|
||||
return intVal, true
|
||||
}
|
||||
}
|
||||
|
||||
// Try float
|
||||
var floatVal float64
|
||||
if _, err := fmt.Sscanf(s, "%f", &floatVal); err == nil {
|
||||
return floatVal, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
423
model/parsers/functiongemma_test.go
Normal file
423
model/parsers/functiongemma_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestFunctionGemmaParser feeds the streaming parser token-sized chunks
// (mimicking streamed decoding) and checks the accumulated plain text and
// tool calls against each scenario's expectations.
func TestFunctionGemmaParser(t *testing.T) {
	tests := []struct {
		name          string
		chunks        []string // streamed input, one Add call per element
		tools         []api.Tool
		expectedCalls []api.ToolCall
		expectedText  string
	}{
		{
			name:          "plain_content",
			chunks:        []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
			expectedCalls: nil,
			expectedText:  "Hello, world!",
		},
		{
			name: "simple_tool_call",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "get", "_", "weather", "{",
				"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name: "get_weather",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Properties: map[string]api.ToolProperty{
								"city": {Type: api.PropertyType{"string"}},
							},
						},
					},
				},
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "content_before_tool_call",
			chunks: []string{
				"L", "et", " ", "me", " ", "check", ".",
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "get", "_", "weather", "{",
				"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
					},
				},
			},
			expectedText: "Let me check.",
		},
		{
			// Unescaped values are parsed as numbers (int64).
			name: "numeric_arguments",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "add", "{",
				"a", ":", "1", ",", "b", ":", "2",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "add",
						Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "boolean_arguments",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "set", "_", "flag", "{",
				"enabled", ":", "true", ",", "verbose", ":", "false",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "set_flag",
						Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "multiple_tool_calls",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "get", "_", "weather", "{",
				"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
				"}", "<", "end", "_", "function", "_", "call", ">",
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "get", "_", "weather", "{",
				"city", ":", "<", "escape", ">", "London", "<", "escape", ">",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
					},
				},
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: api.ToolCallFunctionArguments{"city": "London"},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "array_argument",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "process", "{",
				"items", ":", "[",
				"<", "escape", ">", "a", "<", "escape", ">", ",",
				"<", "escape", ">", "b", "<", "escape", ">", ",",
				"<", "escape", ">", "c", "<", "escape", ">",
				"]",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "process",
						Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "object_argument",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "update", "{",
				"data", ":", "{",
				"name", ":", "<", "escape", ">", "test", "<", "escape", ">", ",",
				"value", ":", "42",
				"}",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "update",
						Arguments: api.ToolCallFunctionArguments{
							"data": map[string]any{"name": "test", "value": int64(42)},
						},
					},
				},
			},
			expectedText: "",
		},
		{
			name:          "empty_input",
			chunks:        []string{},
			expectedCalls: nil,
			expectedText:  "",
		},
		{
			name: "tool_call_with_no_arguments",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "get", "_", "time", "{", "}",
				"<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_time",
						Arguments: api.ToolCallFunctionArguments{},
					},
				},
			},
			expectedText: "",
		},
		{
			// "<" in plain text must not be swallowed as a tag prefix.
			name: "content_with_angle_brackets",
			chunks: []string{
				"The", " ", "result", " ", "is", " ", "a", " ", "<", "value", ">", " ", "tag",
			},
			expectedCalls: nil,
			expectedText:  "The result is a <value> tag",
		},
		{
			name: "float_argument",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "set", "_", "temp", "{",
				"value", ":", "3", ".", "14",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "set_temp",
						Arguments: api.ToolCallFunctionArguments{"value": 3.14},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "content_after_tool_call",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "test", "{", "}",
				"<", "end", "_", "function", "_", "call", ">",
				"Done", "!",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "test",
						Arguments: api.ToolCallFunctionArguments{},
					},
				},
			},
			expectedText: "Done!",
		},
		{
			name: "unicode_content_and_arguments",
			chunks: []string{
				"こんにちは", " ",
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "greet", "{",
				"name", ":", "<", "escape", ">", "日本語", "<", "escape", ">",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "greet",
						Arguments: api.ToolCallFunctionArguments{"name": "日本語"},
					},
				},
			},
			expectedText: "こんにちは ",
		},
		{
			name: "multiple_params_sorted",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "search", "{",
				"query", ":", "<", "escape", ">", "test", "<", "escape", ">", ",",
				"limit", ":", "10", ",",
				"offset", ":", "0",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "search",
						Arguments: api.ToolCallFunctionArguments{
							"query":  "test",
							"limit":  int64(10),
							"offset": int64(0),
						},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "nested_object_argument",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "create", "{",
				"config", ":", "{",
				"settings", ":", "{",
				"enabled", ":", "true", ",",
				"name", ":", "<", "escape", ">", "test", "<", "escape", ">",
				"}",
				"}",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "create",
						Arguments: api.ToolCallFunctionArguments{
							"config": map[string]any{
								"settings": map[string]any{
									"enabled": true,
									"name":    "test",
								},
							},
						},
					},
				},
			},
			expectedText: "",
		},
		{
			// A "<start" that never completes the tag must be emitted as text.
			name: "partial_start_tag_in_content",
			chunks: []string{
				"Hello", " ", "<", "start", " ", "world",
			},
			expectedCalls: nil,
			expectedText:  "Hello <start world",
		},
		{
			name: "parallel_tool_calls",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "get", "_", "weather", "{",
				"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
				"}", "<", "end", "_", "function", "_", "call", ">",
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "get", "_", "time", "{",
				"timezone", ":", "<", "escape", ">", "UTC", "<", "escape", ">",
				"}", "<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
					},
				},
				{
					Function: api.ToolCallFunction{
						Name:      "get_time",
						Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
					},
				},
			},
			expectedText: "",
		},
		{
			name: "content_between_tool_calls",
			chunks: []string{
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "first", "{", "}",
				"<", "end", "_", "function", "_", "call", ">",
				"Some", " ", "text", " ", "here",
				"<", "start", "_", "function", "_", "call", ">",
				"call", ":", "second", "{", "}",
				"<", "end", "_", "function", "_", "call", ">",
			},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "first",
						Arguments: api.ToolCallFunctionArguments{},
					},
				},
				{
					Function: api.ToolCallFunction{
						Name:      "second",
						Arguments: api.ToolCallFunctionArguments{},
					},
				},
			},
			expectedText: "Some text here",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser := &FunctionGemmaParser{}
			parser.Init(tt.tools, nil, nil)

			var allContent string
			var allCalls []api.ToolCall

			// Stream chunks one at a time; done is true on the final chunk.
			for i, chunk := range tt.chunks {
				done := i == len(tt.chunks)-1
				content, _, calls, err := parser.Add(chunk, done)
				assert.NoError(t, err)
				allContent += content
				allCalls = append(allCalls, calls...)
			}

			// Handle empty chunks case: flush the parser with an empty, final Add.
			if len(tt.chunks) == 0 {
				content, _, calls, err := parser.Add("", true)
				assert.NoError(t, err)
				allContent = content
				allCalls = calls
			}

			assert.Equal(t, tt.expectedText, allContent)
			assert.Equal(t, tt.expectedCalls, allCalls)
		})
	}
}
|
||||
|
||||
func TestFunctionGemmaParser_HasSupport(t *testing.T) {
|
||||
parser := &FunctionGemmaParser{}
|
||||
assert.True(t, parser.HasToolSupport())
|
||||
assert.False(t, parser.HasThinkingSupport())
|
||||
}
|
||||
@@ -66,6 +66,8 @@ func ParserForName(name string) Parser {
|
||||
return &Olmo3ThinkParser{}
|
||||
case "nemotron-3-nano":
|
||||
return &Nemotron3NanoParser{}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaParser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
287
model/renderers/functiongemma.go
Normal file
287
model/renderers/functiongemma.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// FunctionGemmaRenderer renders a conversation (and optional tool
// declarations) into the FunctionGemma prompt format: Gemma-style
// <start_of_turn>/<end_of_turn> markers plus <start_function_*> markup for
// declarations, calls, and responses.
type FunctionGemmaRenderer struct{}

// defaultSystemMessage introduces tool declarations in the developer turn.
// Render skips it when the user-provided system text already equals it.
const defaultSystemMessage = "You can do function calling with the following functions:"
|
||||
|
||||
// Render serializes a conversation into the FunctionGemma prompt format.
// A leading "system" or "developer" message becomes the developer turn, with
// tool declarations appended to it. Tool calls and their responses are glued
// together without intervening turn markers. thinkValue is currently unused;
// the error result is always nil (the signature matches the renderer
// interface).
func (r *FunctionGemmaRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("<bos>")

	// Peel off a leading system/developer message; everything else is
	// rendered by the main loop below.
	var systemMessage string
	var loopMessages []api.Message
	if len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer") {
		systemMessage = messages[0].Content
		loopMessages = messages[1:]
	} else {
		loopMessages = messages
	}

	// Emit the developer turn when there is system text or tools to declare.
	if systemMessage != "" || len(tools) > 0 {
		sb.WriteString("<start_of_turn>developer\n")
		if systemMessage != "" {
			sb.WriteString(strings.TrimSpace(systemMessage))
		}
		if len(tools) > 0 {
			if systemMessage != "" {
				sb.WriteString("\n")
			}
			if strings.TrimSpace(systemMessage) != defaultSystemMessage {
				// Only add default message if user does not provide it
				sb.WriteString(defaultSystemMessage)
			}
		}
		for _, tool := range tools {
			sb.WriteString(r.renderToolDeclaration(tool))
		}
		sb.WriteString("<end_of_turn>\n")
	}

	// Track previous message type for tool response handling: a response
	// attaches directly to its call, and the turn after a response skips its
	// start marker because the turn is still open.
	prevMessageType := ""

	for i, message := range loopMessages {
		switch message.Role {
		case "assistant":
			if prevMessageType != "tool_response" {
				sb.WriteString("<start_of_turn>model\n")
			}
			prevMessageType = ""

			if message.Content != "" {
				sb.WriteString(strings.TrimSpace(message.Content))
			}

			if len(message.ToolCalls) > 0 {
				for _, tc := range message.ToolCalls {
					sb.WriteString(r.formatToolCall(tc))
				}
				// After tool calls, expect tool responses: keep the turn
				// open and start the response block.
				if i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" {
					sb.WriteString("<start_function_response>")
					prevMessageType = "tool_call"
				} else {
					sb.WriteString("<end_of_turn>\n")
				}
			} else {
				sb.WriteString("<end_of_turn>\n")
			}

		case "user":
			if prevMessageType != "tool_response" {
				sb.WriteString("<start_of_turn>user\n")
			}
			prevMessageType = ""
			sb.WriteString(strings.TrimSpace(message.Content))
			sb.WriteString("<end_of_turn>\n")

		case "tool":
			toolName := ""
			// Find the tool name from the previous assistant's tool call:
			// the Nth tool response after that assistant message maps to
			// its Nth tool call.
			for j := i - 1; j >= 0; j-- {
				if loopMessages[j].Role == "assistant" && len(loopMessages[j].ToolCalls) > 0 {
					// Count how many tool messages came before this one
					toolIdx := 0
					for k := j + 1; k < i; k++ {
						if loopMessages[k].Role == "tool" {
							toolIdx++
						}
					}
					if toolIdx < len(loopMessages[j].ToolCalls) {
						toolName = loopMessages[j].ToolCalls[toolIdx].Function.Name
					}
					break
				}
			}

			if prevMessageType != "tool_call" {
				// Not directly preceded by a call in the output stream, so
				// open the response block here.
				sb.WriteString("<start_function_response>")
			}
			sb.WriteString("response:" + toolName + "{" + r.formatArgValue(message.Content) + "}<end_function_response>")
			prevMessageType = "tool_response"

		default:
			// Unknown roles render as a generic turn.
			sb.WriteString("<start_of_turn>" + message.Role + "\n")
			sb.WriteString(strings.TrimSpace(message.Content))
			sb.WriteString("<end_of_turn>\n")
		}
	}

	// Prompt the model for its next turn unless we just emitted a tool
	// response, in which case the model continues in the same open turn.
	if prevMessageType != "tool_response" {
		sb.WriteString("<start_of_turn>model\n")
	}

	return sb.String(), nil
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
fn := tool.Function
|
||||
sb.WriteString("<start_function_declaration>declaration:" + fn.Name + "{")
|
||||
sb.WriteString("description:<escape>" + fn.Description + "<escape>")
|
||||
|
||||
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
|
||||
sb.WriteString(",parameters:{")
|
||||
|
||||
needsComma := false
|
||||
|
||||
// Only include properties:{} if there are actual properties
|
||||
if len(fn.Parameters.Properties) > 0 {
|
||||
sb.WriteString("properties:{")
|
||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||
sb.WriteString("}")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if len(fn.Parameters.Required) > 0 {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("required:[")
|
||||
for i, req := range fn.Parameters.Required {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("<escape>" + req + "<escape>")
|
||||
}
|
||||
sb.WriteString("]")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if fn.Parameters.Type != "" {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("type:<escape>" + strings.ToUpper(fn.Parameters.Type) + "<escape>")
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
|
||||
sb.WriteString("}<end_function_declaration>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
|
||||
keys := make([]string, 0, len(props))
|
||||
for k := range props {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, name := range keys {
|
||||
prop := props[name]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
|
||||
sb.WriteString(name + ":{description:<escape>")
|
||||
sb.WriteString(prop.Description)
|
||||
sb.WriteString("<escape>")
|
||||
|
||||
if len(prop.Type) > 0 {
|
||||
sb.WriteString(",type:<escape>" + strings.ToUpper(prop.Type[0]) + "<escape>")
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
|
||||
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value := tc.Function.Arguments[key]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(value))
|
||||
}
|
||||
|
||||
sb.WriteString("}<end_function_call>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatArgValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return "<escape>" + v + "<escape>"
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case float64:
|
||||
if v == float64(int64(v)) {
|
||||
return fmt.Sprintf("%d", int64(v))
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case map[string]any:
|
||||
return r.formatMapValue(v)
|
||||
case []any:
|
||||
return r.formatArrayValue(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatMapValue(m map[string]any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("{")
|
||||
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatArrayValue(arr []any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[")
|
||||
|
||||
for i, item := range arr {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(r.formatArgValue(item))
|
||||
}
|
||||
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
514
model/renderers/functiongemma_test.go
Normal file
514
model/renderers/functiongemma_test.go
Normal file
@@ -0,0 +1,514 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic_user_message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "with_system_message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou are helpful<end_of_turn>\n<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "with_developer_role",
|
||||
messages: []api.Message{
|
||||
{Role: "developer", Content: "You are a coding assistant"},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou are a coding assistant<end_of_turn>\n<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "custom_system_message_with_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a weather expert."},
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Custom system message is preserved, tools are appended
|
||||
expected: "<bos><start_of_turn>developer\nYou are a weather expert.\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "developer_role_with_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "developer", Content: "Be concise."},
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Developer role message is preserved, tools are appended
|
||||
expected: "<bos><start_of_turn>developer\nBe concise.\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "multi_turn",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello!"},
|
||||
{Role: "user", Content: "More"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\nHello!<end_of_turn>\n<start_of_turn>user\nMore<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "with_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "tool_call",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "assistant_content_with_tool_call",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\nLet me check.<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "numeric_arguments",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Add"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "3"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "add",
|
||||
Description: "Add numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"number"}},
|
||||
"b": {Type: api.PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:add{description:<escape>Add numbers<escape>,parameters:{properties:{a:{description:<escape><escape>,type:<escape>NUMBER<escape>},b:{description:<escape><escape>,type:<escape>NUMBER<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nAdd<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:add{a:1,b:2}<end_function_call><start_function_response>response:add{<escape>3<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "empty_messages",
|
||||
messages: []api.Message{},
|
||||
expected: "<bos><start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "tool_with_required_params",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Gets the weather for a given city",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
|
||||
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Required params are escaped: required:[<escape>city<escape>]
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Gets the weather for a given city<escape>,parameters:{properties:{city:{description:<escape>City Name<escape>,type:<escape>STRING<escape>},country:{description:<escape>Country Name<escape>,type:<escape>STRING<escape>}},required:[<escape>city<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "multiple_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather and time?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_time",
|
||||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Multiple tool declarations are consecutive
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:get_time{description:<escape>Get current time<escape>,parameters:{properties:{timezone:{description:<escape>Timezone<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather and time?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "parallel_tool_calls",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather and time?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
{Role: "tool", Content: "12:00"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_time",
|
||||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Multiple tool calls and responses are consecutive
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:get_time{description:<escape>Get current time<escape>,parameters:{properties:{timezone:{description:<escape>Timezone<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather and time?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_call>call:get_time{timezone:<escape>UTC<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response><start_function_response>response:get_time{<escape>12:00<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "user_after_tool_response",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
{Role: "user", Content: "Thanks! What about London?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// User message after tool response gets concatenated (user reverted to this behavior)
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>Thanks! What about London?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
// Edge cases
|
||||
{
|
||||
name: "tool_empty_properties",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "test_fn",
|
||||
Description: "",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Empty properties are omitted
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test_fn{description:<escape><escape>,parameters:{type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "unicode_content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "こんにちは 🎉"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nこんにちは 🎉<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "newlines_in_content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Line 1\nLine 2\nLine 3"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nLine 1\nLine 2\nLine 3<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "special_chars_in_content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test <tag> & \"quotes\" chars"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nTest <tag> & \"quotes\" chars<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "boolean_argument",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Set flag"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_flag",
|
||||
Arguments: api.ToolCallFunctionArguments{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "done"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "set_flag",
|
||||
Description: "Set a flag",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:set_flag{description:<escape>Set a flag<escape>,parameters:{properties:{enabled:{description:<escape>Flag value<escape>,type:<escape>BOOLEAN<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nSet flag<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:set_flag{enabled:true}<end_function_call><start_function_response>response:set_flag{<escape>done<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "multiple_required_params",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "test",
|
||||
Description: "Test",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"a", "b", "c"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"string"}, Description: "A"},
|
||||
"b": {Type: api.PropertyType{"string"}, Description: "B"},
|
||||
"c": {Type: api.PropertyType{"string"}, Description: "C"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test{description:<escape>Test<escape>,parameters:{properties:{a:{description:<escape>A<escape>,type:<escape>STRING<escape>},b:{description:<escape>B<escape>,type:<escape>STRING<escape>},c:{description:<escape>C<escape>,type:<escape>STRING<escape>}},required:[<escape>a<escape>,<escape>b<escape>,<escape>c<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "array_type_param",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "test",
|
||||
Description: "Test",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test{description:<escape>Test<escape>,parameters:{properties:{items:{description:<escape>List of items<escape>,type:<escape>ARRAY<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
renderer := &FunctionGemmaRenderer{}
|
||||
result, err := renderer.Render(tt.messages, tt.tools, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -78,6 +78,8 @@ func rendererForName(name string) Renderer {
|
||||
return renderer
|
||||
case "nemotron-3-nano":
|
||||
return &Nemotron3NanoRenderer{}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaRenderer{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user