diff --git a/discover/runner.go b/discover/runner.go index 15fac2f17..d173959ad 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -123,13 +123,15 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev go func(i int) { defer wg.Done() var envVar string + if devices[i].Library == "ROCm" { if runtime.GOOS != "linux" { envVar = "HIP_VISIBLE_DEVICES" } else { envVar = "ROCR_VISIBLE_DEVICES" } + } else if devices[i].Library == "CUDA" { envVar = "CUDA_VISIBLE_DEVICES" - } else if devices[i].Library == "VULKAN" { + } else if devices[i].Library == "Vulkan" { envVar = "GGML_VK_VISIBLE_DEVICES" }