Compare commits


1 Commit

Author     SHA1        Message                                         Date
jmorganca  9b58aecd3e  llm: build with GGML_STATIC=on for static lib   2024-07-06 02:50:25 -04:00
105 changed files with 1057 additions and 1627 deletions

View File

@@ -147,7 +147,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -169,7 +169,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -84,9 +84,6 @@ type ChatRequest struct {
// Model is the model name, as in [GenerateRequest].
Model string `json:"model"`
// Template overrides the model's default prompt template.
Template string `json:"template"`
// Messages is the message history of the chat and can be used to keep a chat memory.
Messages []Message `json:"messages"`
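
For context, ChatRequest is the body of the /api/chat endpoint, and this hunk concerns its per-request Template override. A minimal client-side sketch of populating the struct, trimmed to the fields visible above (the real struct has more fields):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Message and ChatRequest are cut down to the fields shown in the hunk.
    type Message struct {
        Role    string `json:"role"`
        Content string `json:"content"`
    }

    type ChatRequest struct {
        Model    string    `json:"model"`
        Messages []Message `json:"messages"`
    }

    func main() {
        req := ChatRequest{
            Model: "llama3", // hypothetical model name
            Messages: []Message{
                {Role: "user", Content: "Hello"},
            },
        }
        b, _ := json.Marshal(req)
        fmt.Println(string(b))
    }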

View File

@@ -947,7 +947,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req := &api.ChatRequest{
Model: opts.Model,
Template: opts.Template,
Messages: opts.Messages,
Format: opts.Format,
Options: opts.Options,

View File

@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
)
@@ -206,17 +205,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate:
mTemplate := sb.String()
sb.Reset()
_, err := template.Parse(mTemplate)
if err != nil {
multiline = MultilineNone
scanner.Prompt.UseAlt = false
fmt.Println("The template is invalid.")
continue
}
opts.Template = mTemplate
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
multiline = MultilineNone
@@ -378,15 +369,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
mTemplate := sb.String()
sb.Reset()
_, err := template.Parse(mTemplate)
if err != nil {
fmt.Println("The template is invalid.")
continue
}
opts.Template = mTemplate
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
sb.Reset()

View File

@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs' VRAM.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs' VRAM.
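
A hedged sketch of reading these two settings server-side with the documented defaults; the envInt helper is hypothetical, and the fixed default of 4 stands in for the real auto-selection between 4 and 1:

    package main

    import (
        "fmt"
        "os"
        "strconv"
    )

    // envInt is a hypothetical helper: read an integer env var, else a default.
    func envInt(key string, def int) int {
        if v := os.Getenv(key); v != "" {
            if n, err := strconv.Atoi(v); err == nil {
                return n
            }
        }
        return def
    }

    func main() {
        numParallel := envInt("OLLAMA_NUM_PARALLEL", 4) // real default auto-selects 4 or 1
        maxQueue := envInt("OLLAMA_MAX_QUEUE", 512)
        fmt.Println(numParallel, maxQueue)
    }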

go.mod
View File

@@ -18,7 +18,6 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
@@ -72,7 +71,7 @@ require (
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0
golang.org/x/text v0.15.0
golang.org/x/text v0.15.0 // indirect
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -49,17 +49,9 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
}
func commonAMDValidateLibDir() (string, error) {
// Favor our bundled version
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
@@ -95,5 +87,14 @@ func commonAMDValidateLibDir() (string, error) {
}
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}
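
The hunk above moves the bundled ROCm payload (next to the executable) from the front of the search order to the back, so a working system install wins. A distilled sketch of that ordered-candidate lookup; the usability check is stubbed and the HIP_PATH subdirectory is an assumption:

    package main

    import (
        "fmt"
        "os"
        "path/filepath"
    )

    // rocmLibUsable is stubbed here; the real check validates the ROCm libraries.
    func rocmLibUsable(dir string) bool {
        _, err := os.Stat(filepath.Join(dir, "rocblas"))
        return err == nil
    }

    func findROCm() (string, error) {
        var candidates []string
        // system install first, via the explicit HIP env var
        if hip := os.Getenv("HIP_PATH"); hip != "" {
            candidates = append(candidates, filepath.Join(hip, "bin"))
        }
        // bundled payload next to the executable comes last
        if exe, err := os.Executable(); err == nil {
            candidates = append(candidates, filepath.Join(filepath.Dir(exe), "rocm"))
        }
        for _, dir := range candidates {
            if rocmLibUsable(dir) {
                return dir, nil
            }
        }
        return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
    }

    func main() { fmt.Println(findROCm()) }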

View File

@@ -84,8 +84,9 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
}
slog.Debug("hipDriverGetVersion", "version", version)
driverMajor = version / 10000000
driverMinor = (version - (driverMajor * 10000000)) / 100000
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
driverMajor = version / 1000
driverMinor = (version - (driverMajor * 1000)) / 10
return driverMajor, driverMinor, nil
}
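
The replacement arithmetic decodes the packed value with a 1000/10 radix instead of 10000000/100000. For a hypothetical packed version of 60140, the new formula yields major 60 and minor 14:

    package main

    import "fmt"

    func main() {
        version := 60140 // hypothetical packed value from hipDriverGetVersion
        driverMajor := version / 1000
        driverMinor := (version - driverMajor*1000) / 10
        fmt.Println(driverMajor, driverMinor) // 60 14
    }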

View File

@@ -22,8 +22,8 @@ const (
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
)
func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,11 +35,12 @@ func AMDGetGPUInfo() []RocmGPUInfo {
}
defer hl.Release()
driverMajor, driverMinor, err := hl.AMDDriverVersion()
if err != nil {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug("error looking up amd driver version", "error", err)
}
// TODO - this reports incorrect version information, so omitting for now
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
// if err != nil {
// // For now this is benign, but we may eventually need to fail compatibility checks
// slog.Debug("error looking up amd driver version", "error", err)
// }
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
@@ -131,8 +132,10 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory,
Name: name,
Compute: gfx,
DriverMajor: driverMajor,
DriverMinor: driverMinor,
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
},
index: i,
}

View File

@@ -274,28 +274,6 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
@@ -396,14 +374,9 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory")
continue
}
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data",
"gpu", gpu.ID,
"name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group(
"before",
"total", format.HumanBytes2(gpu.TotalMemory),

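The removed block records the skew between the management library's (nvml) free-VRAM figure and the CUDA library's, treating it as OS overhead to set aside on later refreshes. A hedged sketch of that bookkeeping with made-up numbers:

    package main

    import "fmt"

    // osOverhead returns the VRAM the OS holds that the CUDA library cannot
    // see: the excess of the management library's free figure over ours.
    func osOverhead(nvmlFree, libFree uint64) uint64 {
        if nvmlFree > libFree {
            return nvmlFree - libFree
        }
        return 0
    }

    func main() {
        overhead := osOverhead(7<<30, 6<<30) // e.g. nvml sees 7 GiB free, cuda lib 6 GiB
        nvmlFree := uint64(5 << 30)          // a later refresh from the management library
        adjusted := nvmlFree - overhead      // set the recorded overhead aside on updates
        fmt.Println(overhead, adjusted)
    }
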
View File

@@ -56,7 +56,7 @@ func GetCPUInfo() GpuInfoList {
func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: uint64(C.getFreeMemory()),
FreeMemory: 0,
}, nil
}

View File

@@ -2,4 +2,3 @@
#include <stdint.h>
uint64_t getRecommendedMaxVRAM();
uint64_t getPhysicalMemory();
uint64_t getFreeMemory();

View File

@@ -1,5 +1,4 @@
#import <Foundation/Foundation.h>
#import <mach/mach.h>
// go:build darwin
#include "gpu_info_darwin.h"
uint64_t getRecommendedMaxVRAM() {
@@ -9,27 +8,6 @@ uint64_t getRecommendedMaxVRAM() {
return result;
}
// getPhysicalMemory returns the total physical memory in bytes
uint64_t getPhysicalMemory() {
return [NSProcessInfo processInfo].physicalMemory;
}
// getFreeMemory returns the total free memory in bytes, including inactive
// memory that can be reclaimed by the system.
uint64_t getFreeMemory() {
mach_port_t host_port = mach_host_self();
mach_msg_type_number_t host_size = sizeof(vm_statistics64_data_t) / sizeof(integer_t);
vm_size_t pagesize;
vm_statistics64_data_t vm_stat;
host_page_size(host_port, &pagesize);
if (host_statistics64(host_port, HOST_VM_INFO64, (host_info64_t)&vm_stat, &host_size) != KERN_SUCCESS) {
return 0;
}
uint64_t free_memory = (uint64_t)vm_stat.free_count * pagesize;
free_memory += (uint64_t)vm_stat.speculative_count * pagesize;
free_memory += (uint64_t)vm_stat.inactive_count * pagesize;
return free_memory;
return [[NSProcessInfo processInfo] physicalMemory];
}
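
The deleted Objective-C path summed the free, speculative, and inactive page counts from host_statistics64 and multiplied by the page size. The same arithmetic restated in Go with hypothetical counts:

    package main

    import "fmt"

    func main() {
        // hypothetical vm_statistics64 counts and page size
        var (
            pagesize      uint64 = 16384
            freeCount     uint64 = 100000
            specCount     uint64 = 20000
            inactiveCount uint64 = 50000
        )
        // free memory includes inactive pages the system can reclaim
        freeMemory := (freeCount + specCount + inactiveCount) * pagesize
        fmt.Println(freeMemory) // 2785280000 bytes, roughly 2.6 GiB
    }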

View File

@@ -52,8 +52,7 @@ type CPUInfo struct {
type CudaGPUInfo struct {
GpuInfo
OSOverhead uint64 // Memory overhead between the driver library and management library
index int //nolint:unused,nolintlint
index int //nolint:unused,nolintlint
}
type CudaGPUInfoList []CudaGPUInfo

View File

@@ -1413,7 +1413,7 @@ struct llama_server_context
return get_slot(-1);
}
LOG_DEBUG("slot with common prefix found", {{
LOG_INFO("slot with common prefix found", {{
"slot_id", slot->id,
"characters", longest
}});
@@ -1688,8 +1688,22 @@ struct llama_server_context
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
char buf[256];
llama_model_meta_val_str(model, "general.architecture", buf, 256);
bool gemma2 = strcmp(buf, "gemma2") == 0;
int32_t truncate_at = slot.n_ctx;
// truncate at 2/3 of the context length for gemma2 models
// as they do not support context shifts (from the sliding window implementation).
// this way, prompts that almost fit the context length can still generate a full
// response without a sudden stop from hitting the context limit
if (gemma2) {
truncate_at = 2 * slot.n_ctx / 3;
}
// if input prompt is too big, truncate it, if group attention self-extend is disabled
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
{
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_shift = n_left / 2;
@@ -1717,6 +1731,19 @@ struct llama_server_context
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
// Models with sliding window attention do not work with context shifts, so
// limit their prediction to the context length
if (gemma2) {
int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
slot.n_predict = limit;
slot.params.n_predict = limit;
LOG_INFO("model does not support sliding window, limiting generation", {
{"n_ctx", slot.n_ctx},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"n_predict", slot.n_predict}
});
}
if (!slot.params.cache_prompt)
{
llama_sampling_reset(slot.ctx_sampling);

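The two gemma2 branches encode one rule: sliding-window models cannot context-shift, so the prompt is truncated at 2/3 of the context and generation is capped at whatever context remains. With n_ctx = 8192 and a 5000-token prompt, the truncation threshold is 5461 tokens and n_predict is capped at 3192; a sketch of both limits:

    package main

    import "fmt"

    func gemma2Limits(nCtx, nPromptTokens int) (truncateAt, nPredict int) {
        truncateAt = 2 * nCtx / 3       // prompts at or past this get truncated up front
        nPredict = nCtx - nPromptTokens // no context shift, so stop at the context edge
        return
    }

    func main() {
        fmt.Println(gemma2Limits(8192, 5000)) // 5461 3192
    }
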
View File

@@ -77,7 +77,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
init_vars
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu"
echo "Building custom CPU"
build
@@ -93,7 +93,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# -DGGML_AVX512_VBMI -- 2018 Intel Cannon Lake
# -DGGML_AVX512_VNNI -- 2021 Intel Alder Lake
COMMON_CPU_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

View File

@@ -6,9 +6,18 @@ function amdGPUs {
if ($env:AMDGPU_TARGETS) {
return $env:AMDGPU_TARGETS
}
# Current supported rocblas list from ROCm v6.1.2 on windows
# TODO - load from some common data file for linux + windows build consistency
$GPU_LIST = @(
"gfx900"
"gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
"gfx90a:xnack-"
"gfx940"
"gfx941"
"gfx942"
"gfx1010"
"gfx1012"
"gfx1030"
"gfx1100"
"gfx1101"
@@ -195,6 +204,7 @@ function build_static() {
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DGGML_STATIC=on",
"-DGGML_NATIVE=off",
"-DGGML_AVX=off",
"-DGGML_AVX2=off",
@@ -357,7 +367,6 @@ function build_rocm() {
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DHIP_PLATFORM=amd",
"-DGGML_AVX=on",
"-DGGML_AVX2=off",
@@ -386,6 +395,7 @@ function build_rocm() {
sign
install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

View File

@@ -1,11 +1,11 @@
package llm
// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
// #cgo LDFLAGS: -lllama -lggml -lstdc++
// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
// #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
// #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #include <stdlib.h>

View File

@@ -1,11 +1,11 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace28..172640e2 100644
index 73f52435..2b81b4bd 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
@@ -5092,16 +5092,7 @@ static void llm_load_vocab(
// for now, only BPE models have pre-tokenizers
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
vocab.tokenizer_add_space_prefix = false;
vocab.tokenizer_clean_spaces = true;
- if (tokenizer_pre.empty()) {
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
@@ -20,7 +20,7 @@ index 2b9ace28..172640e2 100644
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
tokenizer_pre == "llama3" ||
@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
@@ -5164,7 +5155,8 @@ static void llm_load_vocab(
tokenizer_pre == "jais") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
} else {

View File

@@ -254,6 +254,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--tensor-split", estimate.TensorSplit)
}
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
}
for i := range len(servers) {
dir := availableServers[servers[i]]
if dir == "" {
@@ -675,7 +679,7 @@ type CompletionRequest struct {
Prompt string
Format string
Images []ImageData
Options *api.Options
Options api.Options
}
type CompletionResponse struct {
@@ -695,9 +699,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
defer s.sem.Release(1)
// put an upper limit on num_predict to avoid the model running on forever
// only allow maximum 10 "context shifts" to avoid infinite generation
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
}
request := map[string]any{

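The new guard bounds negative or oversized num_predict values at ten context windows, i.e. at most ten context shifts. With NumCtx = 2048 the cap is 20480; a self-contained check:

    package main

    import "fmt"

    func clampNumPredict(numPredict, numCtx int) int {
        if numPredict < 0 || numPredict > 10*numCtx {
            return 10 * numCtx // allow at most 10 "context shifts"
        }
        return numPredict
    }

    func main() {
        fmt.Println(clampNumPredict(-1, 2048))    // 20480
        fmt.Println(clampNumPredict(50000, 2048)) // 20480
        fmt.Println(clampNumPredict(256, 2048))   // 256
    }
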
View File

@@ -338,16 +338,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []any:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
stops = append(stops, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
}
case []string:
options["stop"] = stop
default:
if r.Stop != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
}
options["stop"] = stops
}
if r.MaxTokens != nil {

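The rewritten switch accepts the OpenAI-style stop field as either a string or a list. Because JSON arrays decode into []any rather than []string, each element has to be type-checked; a self-contained sketch of the accepted shapes (the function name is illustrative):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // normalizeStop converts a decoded JSON "stop" value into []string.
    func normalizeStop(stop any) ([]string, error) {
        switch v := stop.(type) {
        case nil:
            return nil, nil
        case string:
            return []string{v}, nil
        case []any: // JSON arrays decode to []any, so check each element
            var stops []string
            for _, s := range v {
                str, ok := s.(string)
                if !ok {
                    return nil, fmt.Errorf("invalid type for 'stop' field: %T", s)
                }
                stops = append(stops, str)
            }
            return stops, nil
        default:
            return nil, fmt.Errorf("invalid type for 'stop' field: %T", stop)
        }
    }

    func main() {
        var v any
        json.Unmarshal([]byte(`["\n","stop"]`), &v)
        fmt.Println(normalizeStop(v))
    }
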
View File

@@ -3,6 +3,7 @@ package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -15,133 +16,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestMiddlewareRequests(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
})
}
}
func TestMiddlewareResponses(t *testing.T) {
func TestMiddleware(t *testing.T) {
type testCase struct {
Name string
Method string
@@ -155,7 +30,159 @@ func TestMiddlewareResponses(t *testing.T) {
testCases := []testCase{
{
Name: "completions handler error forwarding",
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}
c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}
if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}
if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",

View File

@@ -107,12 +107,9 @@ function gatherDependencies() {
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

View File

@@ -34,8 +34,6 @@ import (
"github.com/ollama/ollama/version"
)
var errCapabilityCompletion = errors.New("completion")
type Capability string
const CapabilityCompletion = Capability("completion")
@@ -64,10 +62,7 @@ type Model struct {
Template *template.Template
}
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
func (m *Model) Has(caps ...Capability) bool {
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
@@ -86,19 +81,15 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion)
return false
}
default:
slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap)
return false
}
}
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
}
return nil
return true
}
func (m *Model) String() string {

View File

@@ -1,83 +1,217 @@
package server
import (
"bytes"
"context"
"fmt"
"log/slog"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
)
type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
return false
})
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err
}
s, err := tokenize(ctx, b.String())
if err != nil {
return "", nil, err
}
c := len(s)
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c += 768 * len(m.Images)
// isResponseNode checks if the node contains .Response
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds {
for _, arg := range cmd.Args {
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
if fieldNode.Ident[0] == "Response" {
return true
}
}
}
if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
n = i
}
}
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
return "", nil, err
}
for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}
return b.String(), images, nil
return false
}
// formatTemplateForResponse formats the template AST to:
// 1. remove all nodes after the first .Response (if generate=true)
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
var found bool
for i, node := range tmpl.Tree.Root.Nodes {
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) {
found = true
if generate {
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
break
}
}
}
}
if !found {
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
formatTemplateForResponse(tmpl, generate)
vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}
tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}
return len(tokens), err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string
images []int
tokens int
}
var p prompt
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
var sb strings.Builder
for range msg.Images {
fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId)
imgId += 1
}
sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}
// calculate token lengths for each prompt, estimating 768 tokens per image
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
if err != nil {
return "", err
}
prompts[i].tokens = tokens + len(prompts[i].images)*768
}
// truncate images and prompts starting from the beginning of the list
// until either one prompt remains or the total tokens fits the context window
// TODO (jmorganca): this doesn't account for the context window room required for the response
for {
var required int
for _, p := range prompts {
required += p.tokens
}
required += 1 // for bos token
if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
break
}
prompt := &prompts[0]
if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}
if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]
if system != "" && prompts[0].System == "" {
prompts[0].System = system
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}
prompts[0].tokens = tokens + len(prompts[0].images)*768
}
continue
}
// stop truncating if there's only one prompt left
break
}
var sb strings.Builder
for i, p := range prompts {
// last prompt should leave the response unrendered (for completion)
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
if err != nil {
return "", err
}
sb.WriteString(rendered)
}
return sb.String(), nil
}

View File

@@ -1,8 +1,6 @@
package server
import (
"bytes"
"context"
"strings"
"testing"
@@ -10,195 +8,208 @@ import (
"github.com/ollama/ollama/template"
)
func tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
images [][]byte
}
cases := []struct {
name string
limit int
msgs []api.Message
expect
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "messages",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "truncate messages",
limit: 1,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
},
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "truncate messages with image",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
},
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "truncate messages with images",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "messages with images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "message with image tag",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "messages with interleaved images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "truncate message with interleaved images",
limit: 1024,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
},
{
name: "message with system prompt",
limit: 2048,
msgs: []api.Message{
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
}
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
if tt.prompt != prompt {
t.Errorf("expected %q, got %q", tt.prompt, prompt)
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}
if len(images) != len(tt.images) {
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
}
for i := range images {
if images[i].ID != i {
t.Errorf("expected ID %d, got %d", i, images[i].ID)
}
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i])
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}
func TestChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
messages []api.Message
window int
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] Hello [/INST]",
},
{
name: "with system message",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
},
{
name: "with implicit response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
},
{
name: "with conversation",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "What are the potion ingredients?"},
{Role: "assistant", Content: "sugar"},
{Role: "user", Content: "Anything else?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
},
{
name: "with truncation",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
{Role: "user", Content: "Why is the sky blue?"},
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
},
window: 10,
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
},
{
name: "images",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
},
window: 1024,
want: "You are a Wizard. [img-0] Hello",
},
{
name: "images truncated",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
},
window: 1024,
want: "You are a Wizard. [img-0] [img-1] Hello",
},
{
name: "empty list",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "",
},
{
name: "empty prompt",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "user", Content: ""},
},
window: 1024,
want: "",
},
}
encode := func(s string) ([]int, error) {
words := strings.Fields(s)
return make([]int, len(words)), nil
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got: %q, want: %q", got, tc.want)
}
})
}

View File

@@ -1,13 +1,13 @@
package server
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net"
"net/http"
@@ -54,8 +54,6 @@ func init() {
gin.SetMode(mode)
}
var errRequired = errors.New("is required")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -69,147 +67,163 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil
}
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, mTemplate string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
if err != nil {
return nil, nil, nil, err
}
if mTemplate != "" {
model.Template, err = template.Parse(mTemplate)
if err != nil {
return nil, nil, nil, err
}
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
opts, err := modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err = <-errCh:
return nil, nil, nil, err
}
return runner.llama, model, &opts, nil
func isSupportedImageType(image []byte) bool {
contentType := http.DetectContentType(image)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
return slices.Contains(allowedTypes, contentType)
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Format != "" && req.Format != "json" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, "", caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
for _, img := range req.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if req.Prompt == "" {
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
DoneReason: "load",
})
return
}
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
tmpl, err := template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt := req.Prompt
if !req.Raw {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
checkpointLoaded := time.Now()
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
tmpl = model.Template
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
if req.System == "" {
req.System = model.System
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
tmpl := m.Template
if req.Template != "" {
tmpl, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sb strings.Builder
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
var b bytes.Buffer
if req.Context != nil {
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(req.Prompt)
b.WriteString(s)
}
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = b.String()
sb.Reset()
if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(prev)
}
sb.WriteString(p)
prompt = sb.String()
}
slog.Debug("generate request", "prompt", prompt, "images", images)
slog.Debug("generate handler", "prompt", prompt)
ch := make(chan any)
var generated strings.Builder
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
ch <- api.GenerateResponse{
fn := func(r llm.CompletionResponse) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: r.Content,
Done: r.Done,
Response: r.Content,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@@ -218,35 +232,77 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
}); err != nil {
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = append(req.Context, tokens...)
}
}
ch <- resp
}
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse
// Accumulate responses into the final response
var final api.GenerateResponse
var sb strings.Builder
for rr := range ch {
switch t := rr.(type) {
for resp := range ch {
switch r := resp.(type) {
case api.GenerateResponse:
sb.WriteString(t.Response)
r = t
sb.WriteString(r.Response)
final = r
case gin.H:
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
r.Response = sb.String()
c.JSON(http.StatusOK, r)
final.Response = sb.String()
c.JSON(http.StatusOK, final)
return
}
@@ -255,17 +311,44 @@ func (s *Server) GenerateHandler(c *gin.Context) {
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, "", []Capability{}, req.Options, req.KeepAlive)
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
model, err := GetModel(req.Model)
if err != nil {
handleScheduleError(c, req.Model, err)
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
@@ -275,14 +358,17 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
resp := api.EmbeddingResponse{
Embedding: embedding,
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) PullModelHandler(c *gin.Context) {
@@ -563,9 +649,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
msgs := make([]api.Message, 0)
for _, msg := range m.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
n := model.ParseName(req.Model)
@@ -1128,55 +1214,132 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// chatPrompt builds up a prompt from a series of messages for the currently loaded model
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
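A usage sketch for the wrapper above (message contents illustrative): the encode closure gives ChatPrompt access to the runner's own tokenizer, presumably so the rendered history can be measured and trimmed to fit numCtx.

prompt, err := chatPrompt(ctx, runner, model.Template, []api.Message{
	{Role: "system", Content: "You are a helpful assistant."},
	{Role: "user", Content: "Hello, how are you?"},
}, opts.NumCtx)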
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, req.Template, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "load",
})
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slog.Debug("chat request", "images", len(images), "prompt", prompt)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
checkpointLoaded := time.Now()
// if the first message is not a system message, then add the model's default system message
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
req.Messages = append([]api.Message{
{
Role: "system",
Content: model.System,
},
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
DoneReason: "load",
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return
}
// only send images that are in the prompt
var i int
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i++
}
}
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
ch := make(chan any)
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
ch <- api.ChatResponse{
fn := func(r llm.CompletionResponse) {
resp := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1189,52 +1352,64 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
}); err != nil {
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- resp
}
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
var r api.ChatResponse
// Accumulate responses into the final response
var final api.ChatResponse
var sb strings.Builder
for rr := range ch {
switch t := rr.(type) {
for resp := range ch {
switch r := resp.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
r = t
sb.WriteString(r.Message.Content)
final = r
case gin.H:
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
r.Message.Content = sb.String()
c.JSON(http.StatusOK, r)
final.Message = api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
return
}
streamResponse(c, ch)
}
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
func handleErrorResponse(c *gin.Context, err error) {
if errors.Is(err, context.Canceled) {
c.JSON(499, gin.H{"error": "request canceled"})
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if errors.Is(err, ErrMaxQueue) {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
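To summarize the mapping implemented by both variants above: a canceled request maps to 499 (client closed request), a full queue (ErrMaxQueue) maps to 503 Service Unavailable, a missing model maps to 404 in the scheduler-based version, and anything else maps to 500 Internal Server Error.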

View File

@@ -545,9 +545,9 @@ func TestCreateDetectTemplate(t *testing.T) {
}
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
})
})

View File

@@ -133,6 +133,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
numParallel = 1
slog.Warn("multimodal models don't support parallel requests yet")
}
// Keep NumCtx and numParallel in sync
if numParallel > 1 {
pending.opts.NumCtx = pending.origNumCtx * numParallel
}
for {
cpus := s.getCpuFn()
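Worked example for the sync above (illustrative numbers): with pending.origNumCtx = 2048 and numParallel = 4, pending.opts.NumCtx becomes 2048 × 4 = 8192, since each parallel slot needs its own full context window.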
@@ -193,36 +197,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
break
}
// Block attempting to load a model larger than system memory + GPU memory
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
// Add available GPU memory to the total pool
// macOS hardware has unified memory so don't double count
if runtime.GOOS != "darwin" {
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
// Block attempting to load a model larger than system memory + GPU memory
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
// Linux will crash if over-allocating memory - return an error to the user.
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
if runtime.GOOS == "linux" {
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
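Illustrative arithmetic for the limit check: with 8 GiB of free system memory, a single idle 24 GiB GPU, and no models loaded, maxSize comes to 8 + 24 = 32 GiB; on macOS the GPU loop is skipped entirely, since unified memory would otherwise be double-counted, leaving maxSize at the free system memory alone. A model estimated at 40 GiB would then exceed the limit and be rejected.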
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
@@ -230,10 +223,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
// simplifying assumption of defaultParallel when in CPU mode
if numParallel <= 0 {
numParallel = defaultParallel
pending.opts.NumCtx = pending.origNumCtx * numParallel
}
pending.opts.NumCtx = pending.origNumCtx * numParallel
if loadedCount == 0 {
slog.Debug("cpu mode with first model, loading")
s.loadFn(pending, ggml, gpus, numParallel)

View File

@@ -1,8 +1 @@
{{- if .Messages }}
{{- if .System }}<start_system>{{ .System }}<end_message>
{{- end }}
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
{{- end }}<start_assistant>
{{- else }}
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- end }}
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>

View File

@@ -1,19 +1,7 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- else if eq .Role "assistant" }}### Response:
{{- end }}
{{ .Content }}
{{ end }}### Response:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}
{{ end }}### Response:
{{ .Response }}
{{- end }}
{{ .Response }}

View File

@@ -1,15 +1,6 @@
{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end }}
{{ .Response }}<|im_end|>

View File

@@ -1,17 +1,5 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{- else if eq .Role "assistant" }}Assistant:
{{- end }} {{ .Content }}
{{ end }}Assistant:
{{- else }}
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
{{- end }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}

View File

@@ -1,13 +1,3 @@
{{- if .Messages }}
{{- if .System }}Source: system
{{ .System }} <step> {{ end }}
{{- range .Messages }}Source: {{ .Role }}
{{ .Content }} <step> {{ end }}Source: assistant
Destination: user
{{ else }}
{{ if .System }} Source: system
{{ .System }} <step>{{ end }} Source: user
@@ -15,5 +5,4 @@ Destination: user
{{ .Prompt }} <step> Source: assistant
Destination: user
{{ .Response }}<step>
{{- end }}
{{ .Response }}<step>

View File

@@ -1,13 +1,3 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{ else if eq .Role "assistant" }}Falcon:
{{ end }}{{ .Content }}
{{ end }}Falcon:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{- end }}
{{ end }}Assistant: {{ .Response }}

View File

@@ -1,16 +1,4 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}<start_of_turn>
{{- if eq .Role "user" }}user
{{- if and $.System (eq $index 0) }}
{{ $.System }}
{{- end }}
{{- else if eq .Role "assistant" }}model
{{- end }}
{{ .Content }}<end_of_turn>
{{ end }}<start_of_turn>model
{{ else }}
<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
{{- end }}
{{ .Response }}<end_of_turn>

View File

@@ -1,16 +1,3 @@
{{- if .Messages }}
{{- if .System }}System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}Question:
{{- else if eq .Role "assistant" }}Answer:
{{- end }}
{{ .Content }}
{{ end }}Answer:
{{ else }}
{{ if .System }}
System:
{{ .System }}
@@ -19,5 +6,4 @@ System:
{{ .Prompt }}
{{ end }}Answer:
{{ .Response }}
{{- end }}
{{ .Response }}

View File

@@ -1,16 +1,3 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
{{- if $.System }}
{{ $.System }}
{{ end }}<</SYS>>
{{ end }}{{ .Content }}
{{- else }} [/INST] {{ .Content }}</s><s>
{{- end }}
{{- end }} [/INST]
{{- else }}
[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}
{{ .Prompt }} [/INST] {{ .Response }}

View File

@@ -1,19 +1,7 @@
{{- if .Messages }}
{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>
{{- end }}
{{ .Response }}<|eot_id|>

View File

@@ -1,20 +1,7 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}@@ Instruction
{{- else if eq .Role "assistant" }}@@ Response
{{- end }}
{{ .Content }}
{{ end }}@@ Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}@@ Instruction
{{ .Prompt }}
{{ end }}@@ Response
{{ .Response }}
{{- end }}
{{ .Response }}

View File

@@ -1,9 +1,6 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
{{ end }}{{ .Content }}
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
{{- end }}
{{- end }}[/INST]
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>

View File

@@ -1,11 +1 @@
{{- if .Messages }}
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
{{- end }}
{{- range .Messages }}GPT Correct
{{- if eq .Role "user" }} User:
{{- else if eq .Role "assistant" }} Assistant:
{{- end }} {{ .Content }}<|end_of_turn|>
{{- end }}GPT Correct Assistant:
{{- else }}
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
{{- end }}
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>

View File

@@ -1,15 +1,6 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}<|end|>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
{{- end }}
{{ .Response }}<|end|>

View File

@@ -1,16 +1,3 @@
{{- if .Messages }}
{{- if .System }}### System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### User:
{{ .Content }}
{{ else if eq .Role "assistant" }}### Assistant:
{{ .Content }}</s>
{{ end }}
{{ end }}### Assistant:
{{ else }}
{{ if .System }}### System:
{{ .System }}
@@ -18,5 +5,4 @@
{{ .Prompt }}
{{ end }}### Assistant:
{{ .Response }}
{{- end }}
{{ .Response }}

View File

@@ -1,17 +1,3 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction
{{ .Content }}
{{ else if eq .Role "assistant" }}### Response
{{ .Content }}<|endoftext|>
{{ end }}
{{- end }}### Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction
@@ -21,4 +7,3 @@
{{ end }}### Response
{{ .Response }}<|endoftext|>
{{- end }}

View File

@@ -5,7 +5,6 @@ import (
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"slices"
@@ -15,7 +14,6 @@ import (
"text/template/parse"
"github.com/agnivade/levenshtein"
"github.com/ollama/ollama/api"
"golang.org/x/exp/maps"
)
@@ -76,59 +74,30 @@ func Named(s string) (*named, error) {
return nil, errors.New("no matching template found")
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
type Template struct {
*template.Template
raw string
}
// response is a template node that can be added to templates that don't already have one
var response = parse.ActionNode{
NodeType: parse.NodeAction,
Pipe: &parse.PipeNode{
NodeType: parse.NodePipe,
Cmds: []*parse.CommandNode{
{
NodeType: parse.NodeCommand,
Args: []parse.Node{
&parse.FieldNode{
NodeType: parse.NodeField,
Ident: []string{"Response"},
},
},
},
},
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero")
tmpl, err := tmpl.Parse(s)
if err != nil {
return nil, err
}
t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
}
return &t, nil
}
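A minimal sketch of the touch-up's effect (the exact output assumes the parse tree behaves as shown above): a template that references neither .Messages nor .Response gains a trailing response node, so executing the underlying template also emits the response.

tmpl, _ := Parse("{{ .Prompt }}")
var b bytes.Buffer
_ = tmpl.Template.Execute(&b, map[string]any{"Prompt": "Hi", "Response": "Hello"})
// b.String() == "HiHello": the appended node rendered .Response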
func (t *Template) String() string {
return t.raw
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
func Parse(s string) (*Template, error) {
t, err := template.New("").Option("missingkey=zero").Parse(s)
if err != nil {
return nil, err
}
return &Template{Template: t, raw: s}, nil
}
func (t *Template) Vars() []string {
var vars []string
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
vars = append(vars, parseNode(n)...)
}
for _, n := range t.Tree.Root.Nodes {
vars = append(vars, parseNode(n)...)
}
set := make(map[string]struct{})
@@ -141,103 +110,6 @@ func (t *Template) Vars() []string {
return vars
}
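For example (consistent with the TestParse cases further down), Vars for "{{ .System }} {{ .Prompt }}" yields ["prompt", "system"]; in the variant of Parse that appends a response node, "response" appears in the result as well.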
type Values struct {
Messages []api.Message
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages)
if slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated,
})
}
var b bytes.Buffer
var prompt, response string
for i, m := range collated {
if m.Role == "user" {
prompt = m.Content
} else {
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
if err := t.Template.Execute(&b, map[string]any{
"System": "",
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}
prompt = ""
response = ""
}
}
var cut bool
tree := t.Template.Copy()
// for the last message, cut everything after "{{ .Response }}"
tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
if slices.Contains(parseNode(n), "Response") {
cut = true
}
return cut
})
if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
}); err != nil {
return err
}
_, err := io.Copy(w, &b)
return err
}
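Worked example of the legacy fallback above (illustrative template and messages): with the template "[INST] {{ .Prompt }} [/INST] {{ .Response }}" and the messages user "Hello", assistant "Hi", user "Bye", the loop renders the completed pair in full, then renders the final user turn from a copy of the tree cut at {{ .Response }}, producing "[INST] Hello [/INST] Hi[INST] Bye [/INST] " (note the trailing space, where generation begins).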
type messages []*api.Message
// collate merges consecutive messages of the same role into a single message.
// It also pulls out and merges messages with Role == "system", which are
// templated separately. As a side effect, it mangles message content, adding
// image tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
var n int
for i := range msgs {
msg := msgs[i]
if msg.Role == "system" {
if system != "" {
system += "\n\n"
}
system += msg.Content
continue
}
for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {
msg.Content = strings.TrimSpace("[img] " + msg.Content)
}
msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
n++
}
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
collated = append(collated, &msg)
}
}
return
}
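Worked example (illustrative): given the messages system "Be brief.", user "What is this?" carrying one image, and user "Please answer.", collate returns system == "Be brief." and a single merged user message with content "[img-0] What is this?\n\nPlease answer.".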
func parseNode(n parse.Node) []string {
switch n := n.(type) {
case *parse.ActionNode:
@@ -280,8 +152,6 @@ func parseNode(n parse.Node) []string {
return names
case *parse.FieldNode:
return n.Ident
case *parse.TemplateNode:
return parseNode(n.Pipe)
}
return nil

View File

@@ -8,11 +8,9 @@ import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
@@ -48,7 +46,7 @@ func TestNamed(t *testing.T) {
t.Fatal(err)
}
tmpl, err := Parse(b.String())
tmpl, err := template.New(s).Parse(b.String())
if err != nil {
t.Fatal(err)
}
@@ -61,81 +59,18 @@ func TestNamed(t *testing.T) {
}
}
func TestTemplate(t *testing.T) {
cases := make(map[string][]api.Message)
for _, mm := range [][]api.Message{
{
{Role: "user", Content: "Hello, how are you?"},
},
{
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
} {
var roles []string
for _, m := range mm {
roles = append(roles, m.Role)
}
cases[strings.Join(roles, "-")] = mm
}
matches, err := filepath.Glob("*.gotmpl")
if err != nil {
t.Fatal(err)
}
for _, match := range matches {
t.Run(match, func(t *testing.T) {
bts, err := os.ReadFile(match)
if err != nil {
t.Fatal(err)
}
tmpl, err := Parse(string(bts))
if err != nil {
t.Fatal(err)
}
for n, tt := range cases {
t.Run(n, func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
t.Fatal(err)
}
expect, err := os.ReadFile(filepath.Join("testdata", match, n))
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}
}
func TestParse(t *testing.T) {
cases := []struct {
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt", "response"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
{"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
}
for _, tt := range cases {
@@ -152,159 +87,3 @@ func TestParse(t *testing.T) {
})
}
}
func TestExecuteWithMessages(t *testing.T) {
type template struct {
name string
template string
}
cases := []struct {
name string
templates []template
values Values
expected string
}{
{
"mistral",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral system",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
What is your name?[/INST] `,
},
{
"chatml",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`},
{"messages", `
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
`,
},
{
"moondream",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
{Role: "assistant", Content: "It's a hot dog."},
{Role: "user", Content: "What's in _this_ image?"},
{Role: "user", Images: []api.ImageData{[]byte("")}},
{Role: "user", Content: "Is it a hot dog?"},
},
},
`Question: [img-0] What's in this image?
Answer: It's a hot dog.
Question: What's in _this_ image?
[img-1]
Is it a hot dog?
Answer: `,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
for _, ttt := range tt.templates {
t.Run(ttt.name, func(t *testing.T) {
tmpl, err := Parse(ttt.template)
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if b.String() != tt.expected {
t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
}
})
}
})
}
}

View File

@@ -1 +0,0 @@
<start_system>You are a helpful assistant.<end_message><start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

View File

@@ -1 +0,0 @@
<start_user>Hello, how are you?<end_message><start_assistant>

View File

@@ -1 +0,0 @@
<start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

View File

@@ -1,10 +0,0 @@
You are a helpful assistant.### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

View File

@@ -1,4 +0,0 @@
### Instruction:
Hello, how are you?
### Response:

View File

@@ -1,10 +0,0 @@
### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

View File

@@ -1,9 +0,0 @@
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

View File

@@ -1,3 +0,0 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant

View File

@@ -1,7 +0,0 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

View File

@@ -1,9 +0,0 @@
System: You are a helpful assistant.
User: Hello, how are you?
Assistant: I'm doing great. How can I help you today?
User: I'd like to show off how chat templating works!
Assistant:

View File

@@ -1,3 +0,0 @@
User: Hello, how are you?
Assistant:

View File

@@ -1,7 +0,0 @@
User: Hello, how are you?
Assistant: I'm doing great. How can I help you today?
User: I'd like to show off how chat templating works!
Assistant:

View File

@@ -1,11 +0,0 @@
Source: system
You are a helpful assistant. <step> Source: user
Hello, how are you? <step> Source: assistant
I'm doing great. How can I help you today? <step> Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -1,5 +0,0 @@
Source: user
Hello, how are you? <step> Source: assistant
Destination: user

View File

@@ -1,9 +0,0 @@
Source: user
Hello, how are you? <step> Source: assistant
I'm doing great. How can I help you today? <step> Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -1,8 +0,0 @@
System: You are a helpful assistant.
User:
Hello, how are you?
Falcon:
I'm doing great. How can I help you today?
User:
I'd like to show off how chat templating works!
Falcon:

View File

@@ -1,3 +0,0 @@
User:
Hello, how are you?
Falcon:

View File

@@ -1,7 +0,0 @@
User:
Hello, how are you?
Falcon:
I'm doing great. How can I help you today?
User:
I'd like to show off how chat templating works!
Falcon:

View File

@@ -1,8 +0,0 @@
<start_of_turn>user
You are a helpful assistant.
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -1,3 +0,0 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

View File

@@ -1,7 +0,0 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -1,13 +0,0 @@
System:
You are a helpful assistant.
Question:
Hello, how are you?
Answer:
I'm doing great. How can I help you today?
Question:
I'd like to show off how chat templating works!
Answer:

View File

@@ -1,4 +0,0 @@
Question:
Hello, how are you?
Answer:

View File

@@ -1,10 +0,0 @@
Question:
Hello, how are you?
Answer:
I'm doing great. How can I help you today?
Question:
I'd like to show off how chat templating works!
Answer:

View File

@@ -1,5 +0,0 @@
[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -1,3 +0,0 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST]

View File

@@ -1,3 +0,0 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -1,10 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,4 +0,0 @@
<|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,8 +0,0 @@
<|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,12 +0,0 @@
You are a helpful assistant.
@@ Instruction
Hello, how are you?
@@ Response
I'm doing great. How can I help you today?
@@ Instruction
I'd like to show off how chat templating works!
@@ Response

View File

@@ -1,4 +0,0 @@
@@ Instruction
Hello, how are you?
@@ Response

View File

@@ -1,10 +0,0 @@
@@ Instruction
Hello, how are you?
@@ Response
I'm doing great. How can I help you today?
@@ Instruction
I'd like to show off how chat templating works!
@@ Response

View File

@@ -1,2 +0,0 @@
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] You are a helpful assistant.
I'd like to show off how chat templating works![/INST]

View File

@@ -1 +0,0 @@
[INST] Hello, how are you?[/INST]

View File

@@ -1 +0,0 @@
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works![/INST]

View File

@@ -1 +0,0 @@
GPT Correct System: You are a helpful assistant.<|end_of_turn|>GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1 +0,0 @@
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1 +0,0 @@
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1,9 +0,0 @@
<|system|>
You are a helpful assistant.<|end|>
<|user|>
Hello, how are you?<|end|>
<|assistant|>
I'm doing great. How can I help you today?<|end|>
<|user|>
I'd like to show off how chat templating works!<|end|>
<|assistant|>

View File

@@ -1,3 +0,0 @@
<|user|>
Hello, how are you?<|end|>
<|assistant|>

View File

@@ -1,7 +0,0 @@
<|user|>
Hello, how are you?<|end|>
<|assistant|>
I'm doing great. How can I help you today?<|end|>
<|user|>
I'd like to show off how chat templating works!<|end|>
<|assistant|>

View File

@@ -1,13 +0,0 @@
### System:
You are a helpful assistant.
### User:
Hello, how are you?
### Assistant:
I'm doing great. How can I help you today?</s>
### User:
I'd like to show off how chat templating works!
### Assistant:

View File

@@ -1,4 +0,0 @@
### User:
Hello, how are you?
### Assistant:

View File

@@ -1,10 +0,0 @@
### User:
Hello, how are you?
### Assistant:
I'm doing great. How can I help you today?</s>
### User:
I'd like to show off how chat templating works!
### Assistant:

View File

@@ -1,12 +0,0 @@
You are a helpful assistant.
### Instruction
Hello, how are you?
### Response
I'm doing great. How can I help you today?<|endoftext|>
### Instruction
I'd like to show off how chat templating works!
### Response

View File

@@ -1,4 +0,0 @@
### Instruction
Hello, how are you?
### Response

View File

@@ -1,10 +0,0 @@
### Instruction
Hello, how are you?
### Response
I'm doing great. How can I help you today?<|endoftext|>
### Instruction
I'd like to show off how chat templating works!
### Response

View File

@@ -1,6 +0,0 @@
You are a helpful assistant.
USER: Hello, how are you?
ASSISTANT: I'm doing great. How can I help you today?</s>
USER: I'd like to show off how chat templating works!
ASSISTANT:

View File

@@ -1,2 +0,0 @@
USER: Hello, how are you?
ASSISTANT:

View File

@@ -1,4 +0,0 @@
USER: Hello, how are you?
ASSISTANT: I'm doing great. How can I help you today?</s>
USER: I'd like to show off how chat templating works!
ASSISTANT:

Some files were not shown because too many files have changed in this diff.