Compare commits
1 Commits
v0.13.1
...
parth/remo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b152860c2 |
@@ -11,6 +11,7 @@ linters:
|
|||||||
- errorlint
|
- errorlint
|
||||||
- exptostd
|
- exptostd
|
||||||
- gocheckcompilerdirectives
|
- gocheckcompilerdirectives
|
||||||
|
- gocritic
|
||||||
- govet
|
- govet
|
||||||
- ineffassign
|
- ineffassign
|
||||||
- intrange
|
- intrange
|
||||||
@@ -22,7 +23,6 @@ linters:
|
|||||||
- nolintlint
|
- nolintlint
|
||||||
- nosprintfhostport
|
- nosprintfhostport
|
||||||
- perfsprint
|
- perfsprint
|
||||||
- prealloc
|
|
||||||
- sloglint
|
- sloglint
|
||||||
- staticcheck
|
- staticcheck
|
||||||
- unconvert
|
- unconvert
|
||||||
|
|||||||
@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
|||||||
|
|
||||||
bts := scanner.Bytes()
|
bts := scanner.Bytes()
|
||||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||||
if response.StatusCode >= http.StatusBadRequest {
|
return fmt.Errorf("unmarshal: %w", err)
|
||||||
return StatusError{
|
|
||||||
StatusCode: response.StatusCode,
|
|
||||||
Status: response.Status,
|
|
||||||
ErrorMessage: string(bts),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return errors.New(string(bts))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.StatusCode == http.StatusUnauthorized {
|
if response.StatusCode == http.StatusUnauthorized {
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
|
|||||||
type testError struct {
|
type testError struct {
|
||||||
message string
|
message string
|
||||||
statusCode int
|
statusCode int
|
||||||
raw bool // if true, write message as-is instead of JSON encoding
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e testError) Error() string {
|
func (e testError) Error() string {
|
||||||
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "plain text error response",
|
|
||||||
responses: []any{
|
|
||||||
"internal server error",
|
|
||||||
},
|
|
||||||
wantErr: "internal server error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "HTML error page",
|
|
||||||
responses: []any{
|
|
||||||
"<html><body>404 Not Found</body></html>",
|
|
||||||
},
|
|
||||||
wantErr: "404 Not Found",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if str, ok := resp.(string); ok {
|
|
||||||
fmt.Fprintln(w, str)
|
|
||||||
flusher.Flush()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
t.Fatalf("failed to encode response: %v", err)
|
t.Fatalf("failed to encode response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
|
|||||||
|
|
||||||
func TestClientDo(t *testing.T) {
|
func TestClientDo(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
response any
|
response any
|
||||||
wantErr string
|
wantErr string
|
||||||
wantStatusCode int
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "immediate error response",
|
name: "immediate error response",
|
||||||
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
|
|||||||
message: "test error message",
|
message: "test error message",
|
||||||
statusCode: http.StatusBadRequest,
|
statusCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
wantErr: "test error message",
|
wantErr: "test error message",
|
||||||
wantStatusCode: http.StatusBadRequest,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "server error response",
|
name: "server error response",
|
||||||
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
|
|||||||
message: "internal error",
|
message: "internal error",
|
||||||
statusCode: http.StatusInternalServerError,
|
statusCode: http.StatusInternalServerError,
|
||||||
},
|
},
|
||||||
wantErr: "internal error",
|
wantErr: "internal error",
|
||||||
wantStatusCode: http.StatusInternalServerError,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "successful response",
|
name: "successful response",
|
||||||
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
|
|||||||
Success: true,
|
Success: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "plain text error response",
|
|
||||||
response: testError{
|
|
||||||
message: "internal server error",
|
|
||||||
statusCode: http.StatusInternalServerError,
|
|
||||||
raw: true,
|
|
||||||
},
|
|
||||||
wantErr: "internal server error",
|
|
||||||
wantStatusCode: http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "HTML error page",
|
|
||||||
response: testError{
|
|
||||||
message: "<html><body>404 Not Found</body></html>",
|
|
||||||
statusCode: http.StatusNotFound,
|
|
||||||
raw: true,
|
|
||||||
},
|
|
||||||
wantErr: "<html><body>404 Not Found</body></html>",
|
|
||||||
wantStatusCode: http.StatusNotFound,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
|
|||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if errResp, ok := tc.response.(testError); ok {
|
if errResp, ok := tc.response.(testError); ok {
|
||||||
w.WriteHeader(errResp.statusCode)
|
w.WriteHeader(errResp.statusCode)
|
||||||
if !errResp.raw {
|
err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
err := json.NewEncoder(w).Encode(map[string]string{
|
"error": errResp.message,
|
||||||
"error": errResp.message,
|
})
|
||||||
})
|
if err != nil {
|
||||||
if err != nil {
|
t.Fatal("failed to encode error response:", err)
|
||||||
t.Fatal("failed to encode error response:", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Write raw message (simulates non-JSON error responses)
|
|
||||||
fmt.Fprint(w, errResp.message)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
|
|||||||
if err.Error() != tc.wantErr {
|
if err.Error() != tc.wantErr {
|
||||||
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
|
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
|
||||||
}
|
}
|
||||||
if tc.wantStatusCode != 0 {
|
|
||||||
if statusErr, ok := err.(StatusError); ok {
|
|
||||||
if statusErr.StatusCode != tc.wantStatusCode {
|
|
||||||
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Errorf("expected StatusError, got %T", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
|||||||
latest.Summary()
|
latest.Summary()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generate(cmd *cobra.Command, opts runOptions) error {
|
func generate(cmd *cobra.Command, opts runOptions) error {
|
||||||
|
|||||||
@@ -29,15 +29,6 @@ type mistral3Model struct {
|
|||||||
SlidingWindow *uint32 `json:"sliding_window"`
|
SlidingWindow *uint32 `json:"sliding_window"`
|
||||||
HiddenAct string `json:"hidden_act"`
|
HiddenAct string `json:"hidden_act"`
|
||||||
VocabSize uint32 `json:"vocab_size"`
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
RopeParameters struct {
|
|
||||||
BetaFast float32 `json:"beta_fast"`
|
|
||||||
BetaSlow float32 `json:"beta_slow"`
|
|
||||||
Factor float32 `json:"factor"`
|
|
||||||
ScalingBeta float32 `json:"llama_4_scaling_beta"`
|
|
||||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
|
||||||
RopeType string `json:"rope_type"`
|
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
|
||||||
} `json:"rope_parameters"`
|
|
||||||
} `json:"text_config"`
|
} `json:"text_config"`
|
||||||
VisionModel struct {
|
VisionModel struct {
|
||||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
@@ -70,13 +61,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
|||||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
||||||
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
||||||
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
|
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
|
||||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
|
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
|
||||||
|
|
||||||
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
|
||||||
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
|
|
||||||
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
|
|
||||||
}
|
|
||||||
|
|
||||||
// Vision configuration
|
// Vision configuration
|
||||||
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("discovering available GPUs...")
|
slog.Info("discovering available GPUs...")
|
||||||
detectIncompatibleLibraries()
|
|
||||||
|
|
||||||
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
|
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
|
||||||
overrideWarnings()
|
overrideWarnings()
|
||||||
@@ -488,16 +487,3 @@ func overrideWarnings() {
|
|||||||
slog.Warn("if GPUs are not correctly discovered, unset and try again")
|
slog.Warn("if GPUs are not correctly discovered, unset and try again")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func detectIncompatibleLibraries() {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
basePath, err := exec.LookPath("ggml-base.dll")
|
|
||||||
if err != nil || basePath == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
|
|
||||||
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -33,9 +33,6 @@ func TestVisionModels(t *testing.T) {
|
|||||||
// Qwen 3 VL mixture of experts
|
// Qwen 3 VL mixture of experts
|
||||||
model: "qwen3-vl:30b",
|
model: "qwen3-vl:30b",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
model: "ministral-3",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range testCases {
|
for _, v := range testCases {
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ var (
|
|||||||
|
|
||||||
// Note: add newer models at the top of the list to test them first
|
// Note: add newer models at the top of the list to test them first
|
||||||
ollamaEngineChatModels = []string{
|
ollamaEngineChatModels = []string{
|
||||||
"ministral-3",
|
|
||||||
"qwen3-coder:30b",
|
"qwen3-coder:30b",
|
||||||
"gpt-oss:20b",
|
"gpt-oss:20b",
|
||||||
"gemma3n:e2b",
|
"gemma3n:e2b",
|
||||||
@@ -168,7 +167,6 @@ var (
|
|||||||
"medllama2",
|
"medllama2",
|
||||||
"megadolphin",
|
"megadolphin",
|
||||||
"minicpm-v",
|
"minicpm-v",
|
||||||
"ministral-3",
|
|
||||||
"mistral-large",
|
"mistral-large",
|
||||||
"mistral-nemo",
|
"mistral-nemo",
|
||||||
"mistral-openorca",
|
"mistral-openorca",
|
||||||
@@ -272,7 +270,6 @@ var (
|
|||||||
"mistral",
|
"mistral",
|
||||||
"qwen2.5",
|
"qwen2.5",
|
||||||
"qwen2",
|
"qwen2",
|
||||||
"ministral-3",
|
|
||||||
"mistral-nemo",
|
"mistral-nemo",
|
||||||
"mistral-small",
|
"mistral-small",
|
||||||
"mixtral:8x22b",
|
"mixtral:8x22b",
|
||||||
|
|||||||
@@ -874,7 +874,7 @@ func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.Devic
|
|||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
|
gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
|
||||||
err := s.verifyLayout(systemInfo, systemGPUs, memory, requireFull, gpuLayers, layers)
|
err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -943,7 +943,7 @@ func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
|
// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
|
||||||
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
|
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
|
||||||
// These sizes will only increase as we go through additional iterations and get additional information.
|
// These sizes will only increase as we go through additional iterations and get additional information.
|
||||||
cpuSize := memory.InputWeights + memory.CPU.Graph
|
cpuSize := memory.InputWeights + memory.CPU.Graph
|
||||||
var vramSize uint64
|
var vramSize uint64
|
||||||
@@ -970,8 +970,8 @@ nextLayer:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if requireFull {
|
if requireFull {
|
||||||
if len(systemGPUs) > 0 && gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
|
if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
|
||||||
slog.Info("model requires more gpu memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
|
slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
|
||||||
return ErrLoadRequiredFull
|
return ErrLoadRequiredFull
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -998,7 +998,7 @@ nextLayer:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(systemGPUs) > 0 && gpuLayers.Sum() == 0 {
|
if gpuLayers.Sum() == 0 {
|
||||||
slog.Debug("insufficient VRAM to load any model layers")
|
slog.Debug("insufficient VRAM to load any model layers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,11 +26,10 @@ func TestLLMServerFitGPU(t *testing.T) {
|
|||||||
expectedErr error
|
expectedErr error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "No GPU",
|
name: "No GPU",
|
||||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||||
numGPU: -1,
|
numGPU: -1,
|
||||||
expected: ml.GPULayersList{},
|
expected: ml.GPULayersList{},
|
||||||
requireFull: true, // Should not try to evict even though we can't load any layers
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Full single GPU",
|
name: "Full single GPU",
|
||||||
|
|||||||
@@ -509,9 +509,11 @@ func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string {
|
|||||||
// to crash at inference time and requires deeper validation before we include
|
// to crash at inference time and requires deeper validation before we include
|
||||||
// it in the supported devices list.
|
// it in the supported devices list.
|
||||||
func (d DeviceInfo) NeedsInitValidation() bool {
|
func (d DeviceInfo) NeedsInitValidation() bool {
|
||||||
// ROCm: rocblas will crash on unsupported devices.
|
// At this time the only library we know needs a 2nd pass is ROCm since
|
||||||
// CUDA: verify CC is supported by the version of the library
|
// rocblas will crash on unsupported devices. We want to find those crashes
|
||||||
return d.Library == "ROCm" || d.Library == "CUDA"
|
// during bootstrap discovery so we can eliminate those GPUs before the user
|
||||||
|
// tries to run inference on them
|
||||||
|
return d.Library == "ROCm"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the init validation environment variable
|
// Set the init validation environment variable
|
||||||
|
|||||||
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||||
positionsScale := m.getScale(ctx, batch.Positions)
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ type TextOptions struct {
|
|||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
headDim, ropeDim int
|
headDim, ropeDim int
|
||||||
eps, ropeBase, ropeScale float32
|
eps, ropeBase, ropeScale float32
|
||||||
ropeOrigPosEmbeddings int
|
|
||||||
ropeScalingBeta float32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextModel struct {
|
type TextModel struct {
|
||||||
@@ -36,7 +34,7 @@ type SelfAttention struct {
|
|||||||
Output *nn.Linear `gguf:"attn_output"`
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
||||||
|
|
||||||
@@ -51,10 +49,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
|
|||||||
v := sa.Value.Forward(ctx, hiddenState)
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
if opts.ropeOrigPosEmbeddings > 0 {
|
|
||||||
q = q.Mul(ctx, positionsScale)
|
|
||||||
}
|
|
||||||
|
|
||||||
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
|
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||||
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||||
return sa.Output.Forward(ctx, kqv)
|
return sa.Output.Forward(ctx, kqv)
|
||||||
@@ -82,11 +76,11 @@ type Layer struct {
|
|||||||
MLP *MLP
|
MLP *MLP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||||
residual := hiddenState
|
residual := hiddenState
|
||||||
|
|
||||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)
|
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||||
|
|
||||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||||
// we need logits for.
|
// we need logits for.
|
||||||
@@ -103,7 +97,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
||||||
|
|
||||||
// image embeddings
|
// image embeddings
|
||||||
@@ -120,36 +114,25 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, o
|
|||||||
lastLayerOutputs = outputs
|
lastLayerOutputs = outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
|
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
return m.Output.Forward(ctx, hiddenState)
|
return m.Output.Forward(ctx, hiddenState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
|
|
||||||
posScale := make([]float32, len(positions))
|
|
||||||
for n, pos := range positions {
|
|
||||||
interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
|
|
||||||
posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
|
|
||||||
}
|
|
||||||
return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTextModel(c fs.Config) *TextModel {
|
func newTextModel(c fs.Config) *TextModel {
|
||||||
return &TextModel{
|
return &TextModel{
|
||||||
Layers: make([]Layer, c.Uint("block_count")),
|
Layers: make([]Layer, c.Uint("block_count")),
|
||||||
TextOptions: &TextOptions{
|
TextOptions: &TextOptions{
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
headDim: int(c.Uint("attention.key_length")),
|
headDim: int(c.Uint("attention.key_length")),
|
||||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||||
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
|
|
||||||
ropeScalingBeta: c.Float("rope.scaling_beta"),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,136 +0,0 @@
|
|||||||
package parsers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ministralParserState identifies which section of the model output the
// Ministral parser is currently consuming.
type ministralParserState int

// The constants are declared with an explicit type so that every state value
// is a ministralParserState rather than an untyped integer constant.
const (
	// ministralCollectingContent: incoming text is regular content.
	ministralCollectingContent ministralParserState = iota
	// ministralCollectingThinkingContent: incoming text is thinking output,
	// up to the "[/THINK]" tag.
	ministralCollectingThinkingContent
	// ministralCollectingToolName: buffering a tool name, up to the "[ARGS]" tag.
	ministralCollectingToolName
	// ministralCollectingToolArgs: buffering the JSON arguments of a tool call.
	ministralCollectingToolArgs
)
|
|
||||||
|
|
||||||
type MinistralParser struct {
|
|
||||||
state ministralParserState
|
|
||||||
buffer strings.Builder
|
|
||||||
tools []api.Tool
|
|
||||||
hasThinkingSupport bool
|
|
||||||
currentTool *api.Tool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MinistralParser) HasToolSupport() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MinistralParser) HasThinkingSupport() bool {
|
|
||||||
return p.hasThinkingSupport
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
|
|
||||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
|
||||||
if !p.HasThinkingSupport() {
|
|
||||||
p.state = ministralCollectingContent
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefill && lastMessage.Content != "" {
|
|
||||||
p.state = ministralCollectingContent
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
p.state = ministralCollectingThinkingContent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
|
||||||
p.tools = tools
|
|
||||||
p.setInitialState(lastMessage)
|
|
||||||
return tools
|
|
||||||
}
|
|
||||||
|
|
||||||
func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
|
|
||||||
for i := range tools {
|
|
||||||
if tools[i].Function.Name == n {
|
|
||||||
return &tools[i], nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("tool '%s' not found", n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
|
||||||
p.buffer.WriteString(s)
|
|
||||||
|
|
||||||
switch p.state {
|
|
||||||
case ministralCollectingContent:
|
|
||||||
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
|
|
||||||
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
|
|
||||||
if before != "" {
|
|
||||||
return before, "", calls, nil
|
|
||||||
}
|
|
||||||
p.state = ministralCollectingToolName
|
|
||||||
} else if strings.Contains(p.buffer.String(), "[THINK]") {
|
|
||||||
p.state = ministralCollectingThinkingContent
|
|
||||||
return "", "", calls, nil
|
|
||||||
} else {
|
|
||||||
p.buffer.Reset()
|
|
||||||
return s, "", calls, nil
|
|
||||||
}
|
|
||||||
case ministralCollectingThinkingContent:
|
|
||||||
if strings.Contains(p.buffer.String(), "[/THINK]") {
|
|
||||||
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
|
|
||||||
p.state = ministralCollectingContent
|
|
||||||
if after != "" {
|
|
||||||
p.buffer.Reset()
|
|
||||||
return after, thinkingContent, calls, nil
|
|
||||||
}
|
|
||||||
return "", thinkingContent, calls, nil
|
|
||||||
} else {
|
|
||||||
p.buffer.Reset()
|
|
||||||
return "", s, calls, nil
|
|
||||||
}
|
|
||||||
case ministralCollectingToolName:
|
|
||||||
if strings.Contains(p.buffer.String(), "[ARGS]") {
|
|
||||||
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
|
|
||||||
|
|
||||||
t, err := toolByName(p.tools, name)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", calls, err
|
|
||||||
}
|
|
||||||
p.currentTool = t
|
|
||||||
p.state = ministralCollectingToolArgs
|
|
||||||
return "", "", calls, nil
|
|
||||||
}
|
|
||||||
return "", "", calls, nil
|
|
||||||
case ministralCollectingToolArgs:
|
|
||||||
if strings.Contains(p.buffer.String(), "}") {
|
|
||||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
|
||||||
before += "}"
|
|
||||||
|
|
||||||
var data map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(before), &data); err != nil {
|
|
||||||
// todo - throw a better error
|
|
||||||
return "", "", calls, err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.state = ministralCollectingContent
|
|
||||||
|
|
||||||
call := api.ToolCall{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: p.currentTool.Function.Name,
|
|
||||||
Arguments: api.ToolCallFunctionArguments(data),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
calls = append(calls, call)
|
|
||||||
return "", "", calls, nil
|
|
||||||
}
|
|
||||||
return "", "", calls, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.buffer.String(), thinking, calls, nil
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
package parsers
|
package parsers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
|
||||||
"unicode"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/harmony"
|
"github.com/ollama/ollama/harmony"
|
||||||
)
|
)
|
||||||
@@ -41,17 +38,16 @@ func ParserForName(name string) Parser {
|
|||||||
if parser, ok := registry.constructors[name]; ok {
|
if parser, ok := registry.constructors[name]; ok {
|
||||||
return parser()
|
return parser()
|
||||||
}
|
}
|
||||||
var p Parser
|
|
||||||
|
|
||||||
switch name {
|
switch name {
|
||||||
case "qwen3-coder":
|
case "qwen3-coder":
|
||||||
p = &Qwen3CoderParser{}
|
parser := &Qwen3CoderParser{}
|
||||||
|
return parser
|
||||||
case "qwen3-vl-instruct":
|
case "qwen3-vl-instruct":
|
||||||
p = &Qwen3VLParser{hasThinkingSupport: false}
|
parser := &Qwen3VLParser{hasThinkingSupport: false}
|
||||||
|
return parser
|
||||||
case "qwen3-vl-thinking":
|
case "qwen3-vl-thinking":
|
||||||
p = &Qwen3VLParser{hasThinkingSupport: true}
|
parser := &Qwen3VLParser{hasThinkingSupport: true}
|
||||||
case "ministral":
|
return parser
|
||||||
p = &MinistralParser{hasThinkingSupport: false}
|
|
||||||
case "passthrough":
|
case "passthrough":
|
||||||
return &PassthroughParser{}
|
return &PassthroughParser{}
|
||||||
case "harmony":
|
case "harmony":
|
||||||
@@ -61,7 +57,6 @@ func ParserForName(name string) Parser {
|
|||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PassthroughParser struct{}
|
type PassthroughParser struct{}
|
||||||
@@ -81,20 +76,3 @@ func (p *PassthroughParser) HasToolSupport() bool {
|
|||||||
func (p *PassthroughParser) HasThinkingSupport() bool {
|
func (p *PassthroughParser) HasThinkingSupport() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// splitAtTag splits the text held in sb around the first occurrence of tag.
//
// When tag is present it returns the text before the tag with trailing
// whitespace removed, and the text after the tag. With trimAfter set, leading
// whitespace is removed from the text after the tag. sb is left holding exactly
// the returned "after" text.
//
// When tag is absent it returns the whole buffered text unchanged together
// with an empty string, and sb is emptied. An empty buffer combined with an
// empty tag is treated the same way instead of panicking.
func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) {
	buffered := sb.String()

	// SplitN yields fewer than two parts when tag does not occur. It yields
	// zero parts for an empty tag on empty input, so test for "< 2", not "== 1".
	parts := strings.SplitN(buffered, tag, 2)
	if len(parts) < 2 {
		sb.Reset()
		return buffered, ""
	}

	before := strings.TrimRightFunc(parts[0], unicode.IsSpace)
	after := parts[1]
	if trimAfter {
		after = strings.TrimLeftFunc(after, unicode.IsSpace)
	}

	sb.Reset()
	sb.WriteString(after)
	return before, after
}
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package parsers
|
package parsers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
@@ -96,164 +95,3 @@ func TestUnknownParserReturnsNil(t *testing.T) {
|
|||||||
t.Error("expected nil for unknown parser")
|
t.Error("expected nil for unknown parser")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSplitAtTag(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
tag string
|
|
||||||
trimAfter bool
|
|
||||||
wantBefore string
|
|
||||||
wantAfter string
|
|
||||||
wantSB string // expected content of strings.Builder after operation
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "basic split with trimAfter true",
|
|
||||||
input: "hello <!-- split --> world",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "hello",
|
|
||||||
wantAfter: "world",
|
|
||||||
wantSB: "world",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "basic split with trimAfter false",
|
|
||||||
input: "hello <!-- split --> world",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: false,
|
|
||||||
wantBefore: "hello",
|
|
||||||
wantAfter: " world",
|
|
||||||
wantSB: " world",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tag at beginning with trimAfter true",
|
|
||||||
input: "<!-- split -->world",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "",
|
|
||||||
wantAfter: "world",
|
|
||||||
wantSB: "world",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tag at beginning with trimAfter false",
|
|
||||||
input: "<!-- split --> world",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: false,
|
|
||||||
wantBefore: "",
|
|
||||||
wantAfter: " world",
|
|
||||||
wantSB: " world",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tag at end with trimAfter true",
|
|
||||||
input: "hello <!-- split -->",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "hello",
|
|
||||||
wantAfter: "",
|
|
||||||
wantSB: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tag at end with trimAfter false",
|
|
||||||
input: "hello <!-- split -->",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: false,
|
|
||||||
wantBefore: "hello",
|
|
||||||
wantAfter: "",
|
|
||||||
wantSB: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple tags splits at first occurrence",
|
|
||||||
input: "hello <!-- split --> world <!-- split --> end",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "hello",
|
|
||||||
wantAfter: "world <!-- split --> end",
|
|
||||||
wantSB: "world <!-- split --> end",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tag not present",
|
|
||||||
input: "hello world",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "hello world",
|
|
||||||
wantAfter: "",
|
|
||||||
wantSB: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty input",
|
|
||||||
input: "",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "",
|
|
||||||
wantAfter: "",
|
|
||||||
wantSB: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "only whitespace before tag",
|
|
||||||
input: " \t\n<!-- split -->world",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "",
|
|
||||||
wantAfter: "world",
|
|
||||||
wantSB: "world",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "only whitespace after tag with trimAfter true",
|
|
||||||
input: "hello<!-- split --> \t\n",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "hello",
|
|
||||||
wantAfter: "",
|
|
||||||
wantSB: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "only whitespace after tag with trimAfter false",
|
|
||||||
input: "hello<!-- split --> \t\n",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: false,
|
|
||||||
wantBefore: "hello",
|
|
||||||
wantAfter: " \t\n",
|
|
||||||
wantSB: " \t\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "complex whitespace trimming",
|
|
||||||
input: " hello \t\n <!-- split --> \n\t world ",
|
|
||||||
tag: "<!-- split -->",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: " hello",
|
|
||||||
wantAfter: "world ",
|
|
||||||
wantSB: "world ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tag with special characters",
|
|
||||||
input: "text <tag attr=\"value\"> more text",
|
|
||||||
tag: "<tag attr=\"value\">",
|
|
||||||
trimAfter: true,
|
|
||||||
wantBefore: "text",
|
|
||||||
wantAfter: "more text",
|
|
||||||
wantSB: "more text",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
sb := &strings.Builder{}
|
|
||||||
sb.WriteString(tt.input)
|
|
||||||
|
|
||||||
before, after := splitAtTag(sb, tt.tag, tt.trimAfter)
|
|
||||||
|
|
||||||
// Check return values
|
|
||||||
if before != tt.wantBefore {
|
|
||||||
t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore)
|
|
||||||
}
|
|
||||||
if after != tt.wantAfter {
|
|
||||||
t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check strings.Builder state
|
|
||||||
if sb.String() != tt.wantSB {
|
|
||||||
t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
|
|||||||
p.buffer.WriteString(s)
|
p.buffer.WriteString(s)
|
||||||
events := p.parseEvents()
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
var contentSb strings.Builder
|
var contentSb strings.Builder
|
||||||
var thinkingSb strings.Builder
|
var thinkingSb strings.Builder
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
@@ -80,7 +81,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
|
|||||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
calls = append(calls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case qwenEventThinkingContent:
|
case qwenEventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
case qwenEventContent:
|
case qwenEventContent:
|
||||||
@@ -90,7 +91,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return contentSb.String(), thinkingSb.String(), calls, nil
|
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3VLParser) parseEvents() []qwenEvent {
|
func (p *Qwen3VLParser) parseEvents() []qwenEvent {
|
||||||
@@ -112,6 +113,19 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent {
|
|||||||
return all
|
return all
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) {
|
||||||
|
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||||
|
before := split[0]
|
||||||
|
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||||
|
after := split[1]
|
||||||
|
if trimAfter {
|
||||||
|
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||||
|
}
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
return before, after // return events
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
|
func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
|
||||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
@@ -130,7 +144,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
|||||||
case CollectingContent:
|
case CollectingContent:
|
||||||
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
||||||
// events = emitContentBeforeTag(p, events, toolOpenTag)
|
// events = emitContentBeforeTag(p, events, toolOpenTag)
|
||||||
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
|
before, _ := splitAtTag(p, toolOpenTag, false)
|
||||||
if len(before) > 0 {
|
if len(before) > 0 {
|
||||||
events = append(events, qwenEventContent{content: before})
|
events = append(events, qwenEventContent{content: before})
|
||||||
}
|
}
|
||||||
@@ -181,7 +195,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
|||||||
}
|
}
|
||||||
case CollectingThinkingContent:
|
case CollectingThinkingContent:
|
||||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||||
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
|
||||||
if len(thinking) > 0 {
|
if len(thinking) > 0 {
|
||||||
events = append(events, qwenEventThinkingContent{content: thinking})
|
events = append(events, qwenEventThinkingContent{content: thinking})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user