diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3f7d5342a..e47cbad91 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12464,7 +12464,7 @@ static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { return props.properties.deviceType; } -static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { +static std::string ggml_backend_vk_get_device_pci_id(int device_idx, int *domain_pci, int *bus_pci, int *device_pci) { GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; @@ -12496,24 +12496,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { const uint32_t pci_device = pci_bus_info.pciDevice; const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning + // Safely convert uint32_t to int (PCI IDs are small, so this is safe) + if (domain_pci) *domain_pci = static_cast(pci_bus_info.pciDomain); + if (bus_pci) *bus_pci = static_cast(pci_bus_info.pciBus); + if (device_pci) *device_pci = static_cast(pci_bus_info.pciDevice); + char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); return std::string(pci_bus_id); } -static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { - if (id.empty()) return false; - unsigned int d = 0, b = 0, dev = 0, func = 0; - // Expected format: dddd:bb:dd.f (all hex) - int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); - if (n < 4) return false; - if (domain) *domain = (int) d; - if (bus) *bus = (int) b; - if (device) *device = (int) dev; - return true; -} - ////////////////////////// struct ggml_backend_vk_device_context { @@ -13023,18 +13016,11 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; - ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); - // Parse numeric PCI components if available - int d = 0, b = 0, devn = 0; - if (ggml_backend_vk_parse_pci_bus_id(ctx->pci_bus_id, &d, &b, &devn)) { - ctx->pciDomainID = d; - ctx->pciBusID = b; - ctx->pciDeviceID = devn; - } else { - ctx->pciDomainID = 0; - ctx->pciBusID = 0; - ctx->pciDeviceID = 0; - } + int domain = 0, bus = 0, device = 0; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i,&domain, &bus, &device); + ctx->pciDomainID = domain; + ctx->pciBusID = bus; + ctx->pciDeviceID = device; ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i,