diff --git a/CMakePresets.json b/CMakePresets.json index 6181eb732..fd0fb9b3a 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -56,7 +56,7 @@ "name": "ROCm 6", "inherits": [ "ROCm" ], "cacheVariables": { - "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" + "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" } }, { diff --git a/README.md b/README.md index 60b23cded..47d0aebd9 100644 --- a/README.md +++ b/README.md @@ -392,6 +392,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) - [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) +- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance) +- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history ### Cloud diff --git a/discover/amd_linux.go b/discover/amd_linux.go index 830fa1df6..06e907391 100644 --- a/discover/amd_linux.go +++ b/discover/amd_linux.go @@ -279,12 +279,13 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { TotalMemory: totalMemory, FreeMemory: (totalMemory - usedMemory), }, - ID: ID, - Name: name, - Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), - MinimumMemory: rocmMinimumMemory, - DriverMajor: driverMajor, - DriverMinor: driverMinor, + ID: ID, + Name: name, + Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), + MinimumMemory: rocmMinimumMemory, + FlashAttention: true, // Supposedly ROCm supports it everywhere + DriverMajor: driverMajor, + DriverMinor: driverMinor, }, usedFilepath: usedFile, index: gpuID, diff --git a/discover/gpu.go b/discover/gpu.go index 2494469a7..c889c4833 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -310,6 +310,7 @@ func GetGPUInfo() GpuInfoList { C.free(unsafe.Pointer(memInfo.err)) continue } + gpuInfo.FlashAttention = driverMajor >= 7 gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) @@ -394,6 +395,7 @@ func GetGPUInfo() GpuInfoList { // TODO - convert this to MinimumMemory based on testing... var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. memInfo.free = C.uint64_t(totalFreeMem) + gpuInfo.FlashAttention = false gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) @@ -423,6 +425,7 @@ func GetGPUInfo() GpuInfoList { continue } + gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index e868dcc1b..29eaaeb7f 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -24,18 +24,32 @@ int check_perfmon(vk_handle_t* rh) { return 0; } -int support_memory_budget(vk_handle_t* rh, VkPhysicalDevice device) { +int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* extension) { VkPhysicalDeviceProperties properties; (*rh->vkGetPhysicalDeviceProperties)(device, &properties); + uint32_t extensionCount; (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, NULL); + + if (extensionCount == 0) { + return 0; + } + VkExtensionProperties* extensions = malloc(extensionCount * sizeof(VkExtensionProperties)); + if (extensions == NULL) { + return 0; + } + (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, extensions); + for (int j = 0; j < extensionCount; j++) { - if (strcmp(extensions[j].extensionName, VK_EXT_MEMORY_BUDGET_EXTENSION_NAME) == 0) { + if (strcmp(extensions[j].extensionName, extension) == 0) { + free(extensions); return 1; } } + + free(extensions); return 0; } @@ -125,6 +139,7 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { } VkInstance instance; + VkApplicationInfo appInfo = {}; appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; appInfo.pNext = NULL; @@ -133,6 +148,7 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { appInfo.pEngineName = "No Engine"; appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0); appInfo.apiVersion = VK_API_VERSION_1_2; + VkInstanceCreateInfo createInfo = {}; createInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; createInfo.pNext = NULL; @@ -141,6 +157,7 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { const char* extensions[] = { VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }; createInfo.ppEnabledExtensionNames = extensions; createInfo.pApplicationInfo = &appInfo; + VkResult result = (*resp->ch.vkCreateInstance)(&createInfo, NULL, &instance); if (result != VK_SUCCESS) { resp->err = strdup("failed to create instance"); @@ -160,25 +177,63 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { resp->num_devices = deviceCount; } +int vk_check_flash_attention(vk_handle_t rh, int i) { + VkInstance instance = rh.vk; + uint32_t deviceCount = rh.num_devices; + + VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); + if (devices == NULL) { + return 0; + } + + VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); + if (result != VK_SUCCESS) { + free(devices); + return 0; + } + + VkPhysicalDeviceProperties properties; + (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); + + int supports_nv_coopmat2 = is_extension_supported(&rh, devices[i], VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME); + if (!supports_nv_coopmat2) { + free(devices); + return 1; + } + + free(devices); + return 0; +} + void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { VkInstance instance = rh.vk; uint32_t deviceCount = rh.num_devices; VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); + if (devices == NULL) { + resp->err = strdup("memory allocation failed for devices array"); + return; + } + VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); if (result != VK_SUCCESS) { + free(devices); resp->err = strdup("failed to enumerate physical devices"); return; } VkPhysicalDeviceProperties properties; (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); - int supports_budget = support_memory_budget(&rh, devices[i]); + + int supports_budget = is_extension_supported(&rh, devices[i], VK_EXT_MEMORY_BUDGET_EXTENSION_NAME); if (!supports_budget) { + free(devices); resp->err = strdup("device does not support memory budget"); return; } + if (properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) { + free(devices); resp->err = strdup("device is a CPU"); return; } @@ -204,6 +259,8 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { } } + free(devices); + resp->err = NULL; snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; @@ -220,6 +277,7 @@ void vk_release(vk_handle_t rh) { (*rh.vkDestroyInstance)(rh.vk, NULL); UNLOAD_LIBRARY(rh.vk_handle); rh.vk_handle = NULL; + #ifdef __linux__ LOG(rh.verbose, "releasing libcap library\n"); UNLOAD_LIBRARY(rh.cap_handle); diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 6025f3e09..1f19be58e 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -60,6 +60,7 @@ typedef struct vk_init_resp void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp); void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); +int vk_check_flash_attention(vk_handle_t rh, int i); void vk_release(vk_handle_t rh); #endif diff --git a/discover/types.go b/discover/types.go index 11a3acec3..b096b9e2e 100644 --- a/discover/types.go +++ b/discover/types.go @@ -36,9 +36,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? UnreliableFreeMemory bool // GPU information - ID string `json:"gpu_id"` // string to use for selection of this specific GPU - Name string `json:"name"` // user friendly name if available - Compute string `json:"compute"` // Compute Capability or gfx + ID string `json:"gpu_id"` // string to use for selection of this specific GPU + Name string `json:"name"` // user friendly name if available + Compute string `json:"compute"` // Compute Capability or gfx + FlashAttention bool `json:"flash_attention"` // is flash attention supported // Driver Information - TODO no need to put this on each GPU DriverMajor int `json:"driver_major,omitempty"` @@ -178,11 +179,7 @@ func (si SystemInfo) GetOptimalThreadCount() int { // For each GPU, check if it does NOT support flash attention func (l GpuInfoList) FlashAttentionSupported() bool { for _, gpu := range l { - supportsFA := gpu.Library == "metal" || - (gpu.Library == "cuda" && gpu.DriverMajor >= 7) || - gpu.Library == "rocm" - - if !supportsFA { + if !gpu.FlashAttention { return false } } diff --git a/envconfig/config.go b/envconfig/config.go index cee40f6a8..53e358155 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -276,7 +276,7 @@ func AsMap() map[string]EnvVar { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"} ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"} - ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which VK AMD devices are visible by numeric ID"} + ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2237e7f51..7f2d61f09 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -314,18 +314,20 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { return fmt.Errorf("unassigned tensor: %s", t.Name) } - bts := make([]byte, t.Size()) - n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts) - if err != nil { - return err + bts := C.malloc(C.size_t(t.Size())) + if bts == nil { + return errors.New("failed to allocate tensor buffer") } + defer C.free(bts) - if n != len(bts) { - return errors.New("short read") + buf := unsafe.Slice((*byte)(bts), t.Size()) + n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf) + if err != nil || n != len(buf) { + return errors.New("read failed") } tensorSetMutex.Lock() - C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size())) + C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size())) tensorSetMutex.Unlock() return nil }) @@ -375,7 +377,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), C.int(len(schedBackends)), C.size_t(maxGraphNodes), - true, + C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)), ), input: deviceBufferTypes[input.d], output: deviceBufferTypes[output.d], diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index adcb3f738..cf5e6b911 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -89,7 +89,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp return nil, nil, err } - // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved? - if !cachePrompt { - numPast = 0 - } - slot.InUse = true slot.lastUsed = time.Now() diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 0a1b73f5a..f8925d119 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) { }) } } + +func TestLoadCacheSlot(t *testing.T) { + tests := []struct { + name string + cache InputCache + prompt []input.Input + wantErr bool + expectedSlotId int + expectedPrompt int // expected length of remaining prompt + }{ + { + name: "Basic cache hit - single user", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input.Input{}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Only token 3 remains + }, + { + name: "Basic cache hit - multi user", + cache: InputCache{ + multiUserCache: true, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input.Input{}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Only token 3 remains + }, + { + name: "Exact match - leave one input", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Should leave 1 token for sampling + }, + { + name: "No available slots", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: true, + lastUsed: time.Now().Add(-time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: true, + expectedSlotId: -1, + expectedPrompt: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt) + + // Check error state + if (err != nil) != tt.wantErr { + t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return // Skip further checks if we expected an error + } + + // Verify slot ID + if slot.Id != tt.expectedSlotId { + t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId) + } + + // Verify slot is now marked in use + if !slot.InUse { + t.Errorf("LoadCacheSlot() slot not marked InUse") + } + + // Verify remaining prompt length + if len(remainingPrompt) != tt.expectedPrompt { + t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v", + len(remainingPrompt), tt.expectedPrompt) + } + }) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index d4c24556c..9a1a549cd 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe params.numKeep = int32(len(inputs)) } + // TODO(jessegross): We should ensure that we always leave minBatch of context space to shift, + // otherwise we might truncate or split the batch against the model's wishes + // Ensure that at least 1 input can be discarded during shift params.numKeep = min(params.numKeep, s.cache.numCtx-1) @@ -366,17 +369,6 @@ func (s *Server) processBatch() error { batchSize := s.batchSize for j, inp := range seq.inputs { - if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { - if len(seq.pendingInputs) == 0 { - err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) - if err != nil { - return err - } - } else { - break - } - } - // If we are required to put following inputs into a single batch then extend the // batch size. Since we are only extending the size the minimum amount possible, this // will cause a break if we have pending inputs. @@ -389,6 +381,20 @@ func (s *Server) processBatch() error { break } + // If the sum of our working set (already processed tokens, tokens we added to this + // batch, required following tokens) exceeds the context size, then trigger a shift + // now so we don't have to do one later when we can't break the batch. + if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx { + if len(seq.pendingInputs) != 0 { + break + } + + err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + if err != nil { + return err + } + } + options.Inputs = append(options.Inputs, inp.Token) if inp.Multimodal != nil { options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) @@ -590,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) diff --git a/sample/samplers.go b/sample/samplers.go index e302f9147..7c12da08b 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) { // topK also sorts the tokens in descending order of logits tokens = topK(tokens, s.topK) - tokens = temperature(tokens, s.temperature) - tokens = softmax(tokens) + // scale and normalize the tokens in place + temperature(tokens, s.temperature) + softmax(tokens) tokens = topP(tokens, s.topP) tokens = minP(tokens, s.minP) diff --git a/sample/transforms.go b/sample/transforms.go index a5efa704e..3f677553f 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any { } // temperature applies scaling to the logits -func temperature(ts []token, temp float32) []token { +func temperature(ts []token, temp float32) { // Ensure temperature clipping near 0 to avoid numerical instability temp = max(temp, 1e-7) for i := range ts { ts[i].value = ts[i].value / temp } - return ts } // softmax applies normalization to the logits -func softmax(ts []token) []token { +func softmax(ts []token) { // Find max logit for numerical stability maxLogit := float32(math.Inf(-1)) for _, t := range ts { @@ -56,8 +55,6 @@ func softmax(ts []token) []token { for i := range ts { ts[i].value /= sum } - - return ts } // topK limits the number of tokens considered to the k highest logits @@ -99,6 +96,7 @@ func topK(ts []token, k int) []token { } // topP limits tokens to those with cumulative probability p +// requires ts to be sorted in descending order of probabilities func topP(ts []token, p float32) []token { if p == 1.0 { return ts @@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token { for i, t := range ts { sum += t.value if sum > float32(p) { - ts = ts[:i+1] - return ts + return ts[:i+1] } } return ts } -// minP limits tokens to those with cumulative probability p +// minP filters tokens with probabilities >= p * max_prob +// requires ts to be sorted in descending order of probabilities func minP(ts []token, p float32) []token { - if p == 1.0 { - return ts - } + maxProb := ts[0].value - maxProb := float32(math.Inf(-1)) - for _, token := range ts { - if token.value > maxProb { - maxProb = token.value + threshold := maxProb * p + + for i, t := range ts { + if t.value < threshold { + return ts[:i] } } - - threshold := maxProb * float32(p) - - // Filter tokens in-place - validTokens := ts[:0] - for i, token := range ts { - if token.value >= threshold { - validTokens = append(validTokens, ts[i]) - } - } - - ts = validTokens return ts } diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 4880dd8f4..7faf30a55 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) { func TestTemperature(t *testing.T) { input := []float32{1.0, 4.0, -2.0, 0.0} - got := temperature(toTokens(input), 0.5) + tokens := toTokens(input) + temperature(tokens, 0.5) want := []float32{2.0, 8.0, -4.0, 0.0} - compareLogits(t, "temperature(0.5)", want, got) + compareLogits(t, "temperature(0.5)", want, tokens) - got = temperature(toTokens(input), 1.0) + input = []float32{1.0, 4.0, -2.0, 0.0} + tokens = toTokens(input) + temperature(tokens, 1.0) want = []float32{1.0, 4.0, -2.0, 0.0} - compareLogits(t, "temperature(1)", want, got) + compareLogits(t, "temperature(1)", want, tokens) - got = temperature(toTokens(input), 0.0) + input = []float32{1.0, 4.0, -2.0, 0.0} + tokens = toTokens(input) + temperature(tokens, 0.0) want = []float32{1e7, 4e7, -2e7, 0.0} - compareLogits(t, "temperature(0)", want, got) + compareLogits(t, "temperature(0)", want, tokens) } func TestSoftmax(t *testing.T) { @@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := softmax(toTokens(tt.input)) + tokens := toTokens(tt.input) + softmax(tokens) if tt.expected != nil { - compareLogits(t, tt.name, tt.expected, got) + compareLogits(t, tt.name, tt.expected, tokens) return } // Check probabilities sum to 1 var sum float32 - for _, token := range got { + for _, token := range tokens { sum += token.value if token.value < 0 || token.value > 1 { t.Errorf("probability out of range [0,1]: got %f", token.value) @@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) { func TestTopK(t *testing.T) { input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} - - // Test k=5 - got := topK(toTokens(input), 5) - if len(got) != 5 { - t.Errorf("topK(5): wrong length: want 5, got %d", len(got)) + tokens := toTokens(input) + tokens = topK(tokens, 5) + if len(tokens) != 5 { + t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens)) } - // Should keep highest 3 values in descending order want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154} - compareLogits(t, "topK(3)", want, got) + compareLogits(t, "topK(3)", want, tokens) - got = topK(toTokens(input), 20) - if len(got) != len(input) { - t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got)) + tokens = toTokens(input) + tokens = topK(tokens, 20) + if len(tokens) != len(input) { + t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens)) } - // Test k=-1 input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} - got = topK(toTokens(input), -1) - if len(got) != len(input) { - t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) + tokens = toTokens(input) + tokens = topK(tokens, -1) + if len(tokens) != len(input) { + t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens)) } - compareLogits(t, "topK(-1)", want, got) + compareLogits(t, "topK(-1)", want, tokens) - // Test k=0 input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} - got = topK(toTokens(input), 0) - if len(got) != len(input) { - t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) + tokens = toTokens(input) + tokens = topK(tokens, 0) + if len(tokens) != len(input) { + t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens)) + } + compareLogits(t, "topK(-1)", want, tokens) + + input = []float32{-1e7, -2e7, -3e7, -4e7} + tokens = toTokens(input) + tokens = topK(tokens, 1) + if len(tokens) < 1 { + t.Error("topK should keep at least one token") } - compareLogits(t, "topK(-1)", want, got) } func TestTopP(t *testing.T) { @@ -153,16 +165,25 @@ func TestTopP(t *testing.T) { tokens := toTokens(input) // First apply temperature and softmax to get probabilities - tokens = softmax(tokens) + softmax(tokens) tokens = topK(tokens, 20) // Then apply topP - got := topP(tokens, 0.95) + tokens = topP(tokens, 0.95) // Should keep tokens until cumsum > 0.95 - if len(got) > 3 { - t.Errorf("topP(0.95): kept too many tokens: got %d", len(got)) - t.Logf("got: %v", got) + if len(tokens) > 3 { + t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + + // Test edge case - ensure at least one token remains + input = []float32{-1e6, -1e6, -1e6} // One dominant token + tokens = toTokens(input) + softmax(tokens) + tokens = topP(tokens, 0.0) // Very small p + if len(tokens) < 1 { + t.Error("topP should keep at least one token") } } @@ -171,14 +192,45 @@ func TestMinP(t *testing.T) { tokens := toTokens(input) // First apply temperature and softmax - tokens = softmax(tokens) + tokens = topK(tokens, 20) + softmax(tokens) - // Then apply minP - got := minP(tokens, 0.2) + tokens = minP(tokens, 1.0) + + if len(tokens) != 1 { + t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens)) + } + + // Test with normal p value + tokens = toTokens(input) // Reset tokens + tokens = topK(tokens, 20) + softmax(tokens) + tokens = minP(tokens, 0.2) // Should keep tokens with prob >= 0.2 * max_prob - if len(got) > 3 { - t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) + if len(tokens) > 3 { + t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + + // Test with zero p value + tokens = toTokens(input) // Reset tokens + tokens = topK(tokens, 20) + softmax(tokens) + tokens = minP(tokens, 0.0) + + // Should keep only the highest probability token + if len(tokens) != len(input) { + t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + + input = []float32{1e-10, 1e-10, 1e-10} + tokens = toTokens(input) + softmax(tokens) + tokens = minP(tokens, 1.0) + if len(tokens) < 1 { + t.Error("minP should keep at least one token even with extreme probabilities") } } @@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - topK(tokensCopy, 10) + tokens = topK(tokensCopy, 10) } }) @@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - topP(tokensCopy, 0.9) + tokens = topP(tokensCopy, 0.9) } }) @@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - minP(tokensCopy, 0.2) + tokens = minP(tokensCopy, 0.2) } }) @@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - topK(tokensCopy, 200000) + tokens = topK(tokensCopy, 200000) } }) }