diff --git a/common/common.cpp b/common/common.cpp
index dbb724fbbbcff..db18e101f79b7 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -832,9 +832,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--main-gpu" || arg == "-mg") {
         CHECK_ARG
         params.main_gpu = std::stoi(argv[i]);
-#ifndef GGML_USE_CUDA_SYCL_VULKAN
+#if !defined(GGML_USE_CUDA_SYCL_VULKAN) || defined(GGML_USE_KOMPUTE)
        fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n");
-#endif // GGML_USE_CUDA_SYCL_VULKAN
+#endif // GGML_USE_CUDA_SYCL_VULKAN || GGML_USE_KOMPUTE
        return true;
    }
    if (arg == "--split-mode" || arg == "-sm") {
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 4a38eeb5c23bd..4a7a99296af8b 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -3,6 +3,9 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 
+#include
+#include
+
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h
index d7903c666cebf..11bfbf2af7dde 100644
--- a/ggml/include/ggml-cuda.h
+++ b/ggml/include/ggml-cuda.h
@@ -3,6 +3,8 @@
 #include "ggml.h"
 #include "ggml-backend.h"
 
+#include
+
 #ifdef GGML_USE_HIPBLAS
 #define GGML_CUDA_NAME "ROCm"
 #define GGML_CUBLAS_NAME "hipBLAS"
@@ -11,11 +13,20 @@
 #define GGML_CUBLAS_NAME "cuBLAS"
 #endif
 
+#define GGML_CUDA_MAX_DEVICES 16
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-#define GGML_CUDA_MAX_DEVICES 16
+struct ggml_cuda_device {
+    uint32_t index;
+    uint64_t heapSize;
+    const char * name;
+};
+
+GGML_API GGML_CALL struct ggml_cuda_device * ggml_cuda_available_devices(size_t * count);
+GGML_API GGML_CALL void ggml_cuda_device_destroy(struct ggml_cuda_device * device);
 
 // backend API
 GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
diff --git a/ggml/include/ggml-kompute.h b/ggml/include/ggml-kompute.h
index 171465456a5b1..cb9e93c67f0f2 100644
--- a/ggml/include/ggml-kompute.h
+++ b/ggml/include/ggml-kompute.h
@@ -22,11 +22,11 @@ struct ggml_vk_device {
     uint64_t maxAlloc;
 };
 
+void ggml_vk_device_destroy(struct ggml_vk_device * device);
 struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
 bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
 bool ggml_vk_has_vulkan(void);
 bool ggml_vk_has_device(void);
-struct ggml_vk_device ggml_vk_current_device(void);
 
 //
 // backend API
diff --git a/ggml/include/ggml-vulkan.h b/ggml/include/ggml-vulkan.h
index af661c2d7d563..d40c02916ee00 100644
--- a/ggml/include/ggml-vulkan.h
+++ b/ggml/include/ggml-vulkan.h
@@ -10,12 +10,22 @@ extern "C" {
 #define GGML_VK_NAME "Vulkan"
 #define GGML_VK_MAX_DEVICES 16
 
-GGML_API void ggml_vk_instance_init(void);
+struct ggml_vk_device {
+    uint32_t index;
+    int type; // same as VkPhysicalDeviceType
+    uint64_t heapSize;
+    const char * name;
+    uint32_t vendorID;
+};
+
+GGML_API GGML_CALL struct ggml_vk_device * ggml_vk_available_devices(size_t * count);
+GGML_API GGML_CALL void ggml_vk_device_destroy(struct ggml_vk_device * device);
 
 // backend API
 GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);
 
 GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
+GGML_API GGML_CALL size_t ggml_backend_vk_idx(ggml_backend_t backend);
 GGML_API GGML_CALL int  ggml_backend_vk_get_device_count(void);
 GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
 GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 39e345b668bc1..73732c99eb6f4 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -34,19 +34,20 @@
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
-#include
+#include
+#include
 #include
+#include
 #include
 #include
 #include
-#include
-#include
-#include
-#include
 #include
+#include
 #include
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
@@ -186,6 +187,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
         CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
         GGML_CUDA_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 
+        info.devices[id].total_vram = prop.totalGlobalMem;
+        auto &name_dst = info.devices[id].name;
+        strncpy(name_dst, prop.name, sizeof name_dst);
+        name_dst[sizeof name_dst - 1] = 0;
+
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
 
@@ -3068,3 +3074,57 @@ GGML_CALL int ggml_backend_cuda_reg_devices() {
     }
     return device_count;
 }
+
+static std::list<ggml_cuda_device> ggml_cuda_available_devices_internal() {
+    std::list<ggml_cuda_device> results;
+
+    const auto & cuda_info = ggml_cuda_info();
+
+    std::unordered_map<std::string, size_t> count_by_name;
+
+    for (int dev_idx = 0; dev_idx < cuda_info.device_count; dev_idx++) {
+        const auto & device = cuda_info.devices[dev_idx];
+
+        std::string name(device.name);
+        size_t n_idx = ++count_by_name[name];
+        if (n_idx > 1) {
+            name += " (" + std::to_string(n_idx) + ")";
+        }
+
+        results.push_back({
+            /* index    = */ uint32_t(dev_idx),
+            /* heapSize = */ uint64_t(device.total_vram),
+            /* name     = */ strdup(name.c_str()),
+        });
+    }
+
+    // std::list::sort is guaranteed to be stable
+    results.sort(
+        [](const ggml_cuda_device & a, const ggml_cuda_device & b) -> bool {
+            return a.heapSize > b.heapSize; // descending
+        }
+    );
+
+    return results;
+}
+
+// public API returns a C-style array
+ggml_cuda_device * ggml_cuda_available_devices(size_t * count) {
+    auto devices = ggml_cuda_available_devices_internal();
+    *count = devices.size();
+    if (devices.empty()) {
+        return nullptr;
+    }
+
+    size_t nbytes = sizeof(ggml_cuda_device) * devices.size();
+    auto * arr = static_cast<ggml_cuda_device *>(malloc(nbytes));
+
+    int i = 0;
+    for (auto & d : devices) { arr[i++] = d; }
+
+    return arr;
+}
+
+void ggml_cuda_device_destroy(ggml_cuda_device * device) {
+    free(const_cast<char *>(device->name));
+}
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 26d9412a23eb6..2225b1fb1b7c5 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -669,6 +669,7 @@ struct ggml_cuda_device_info {
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
         size_t  total_vram;
+        char    name[256];
     };
 
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute.cpp
index ed5f2e3494ba4..c3257f87be2a8 100644
--- a/ggml/src/ggml-kompute.cpp
+++ b/ggml/src/ggml-kompute.cpp
@@ -41,6 +41,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -62,6 +63,8 @@
 
 typedef ggml_fp16_t half;
 
+static const std::shared_ptr<kp::Tensor> nullTensor = nullptr;
+
 static std::string ggml_kompute_format_name(int device) {
     return "Kompute" + std::to_string(device);
 }
@@ -82,26 +85,35 @@ struct ggml_kompute_context {
 static ggml_kompute_context *s_kompute_context = nullptr;
 
 class kompute_manager {
-    kp::Manager *s_mgr = nullptr;
+    std::unique_ptr<kp::Manager>
s_mgr; public: kp::Manager *operator()() { - if (s_mgr && !s_mgr->hasInstance()) { - destroy(); + if (!s_mgr || !s_mgr->hasInstance()) { + s_mgr.reset(new kp::Manager); + } + return s_mgr.get(); + } + + void cleanup() { + if (s_mgr) { + s_mgr->clear(); } - if (!s_mgr) { - s_mgr = new kp::Manager; + } + + void freeDevice() { + if (s_mgr) { + s_mgr->destroy(); } - return s_mgr; } - void destroy() { - delete s_mgr; - s_mgr = nullptr; + void freeInstance() { + s_mgr.reset(); } }; static kompute_manager komputeManager; +static int global_device_ref = 0; struct ggml_vk_memory { void *data = nullptr; @@ -119,6 +131,61 @@ static void enable_sam() { } #endif +void ggml_vk_device_destroy(ggml_vk_device * device) { + free(const_cast(device->name)); +} + +struct ggml_vk_device_cpp: ggml_vk_device { + ggml_vk_device_cpp() = default; + + ggml_vk_device_cpp( + int index, int type, size_t heapSize, const char * name, const char * vendor, int subgroupSize, + uint64_t bufferAlignment, uint64_t maxAlloc + ) + : ggml_vk_device({ + /* index = */ index, + /* type = */ type, + /* heapSize = */ heapSize, + /* name = */ name ? strdup(name) : nullptr, + /* vendor = */ vendor, + /* supgroupSize = */ subgroupSize, + /* bufferAlignment = */ bufferAlignment, + /* maxAlloc = */ maxAlloc + }) + {} + + ggml_vk_device_cpp(ggml_vk_device_cpp && other) + : ggml_vk_device(other) + { + other.steal(); + } + + ggml_vk_device_cpp(ggml_vk_device_cpp & other) + : ggml_vk_device(other) + { + name = strdup(name); + } + + ggml_vk_device_cpp & operator=(ggml_vk_device_cpp && other) { + ggml_vk_device_destroy(this); + static_cast(*this) = other; + other.steal(); + return *this; + } + + ggml_vk_device_cpp & operator=(ggml_vk_device_cpp & other) { + ggml_vk_device_destroy(this); + static_cast(*this) = other; + name = strdup(name); + return *this; + } + + ~ggml_vk_device_cpp() { ggml_vk_device_destroy(this); } + + // release the reference so we don't double-free name + void steal() { name = nullptr; } +}; + static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) { vk::PhysicalDeviceFeatures availableFeatures; physical_device.getFeatures(&availableFeatures); @@ -144,7 +211,6 @@ static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_devi if (!availableFeatures12.storageBuffer8BitAccess || !availableFeatures12.uniformAndStorageBuffer8BitAccess || - !availableFeatures12.shaderFloat16 || !availableFeatures12.shaderInt8) { return false; } @@ -165,8 +231,8 @@ static const char * ggml_vk_getVendorName(uint32_t vendorID) { } } -static std::vector ggml_vk_available_devices_internal(size_t memoryRequired) { - std::vector results; +static std::list ggml_vk_available_devices_internal(size_t memoryRequired) { + std::list results; if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance()) return results; @@ -188,15 +254,19 @@ static std::vector ggml_vk_available_devices_internal(size_t mem const auto & physical_device = physical_devices[i]; VkPhysicalDeviceProperties dev_props = physical_device.getProperties(); - VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties(); const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion); const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion); if (major < 1 || minor < 2) continue; + if (dev_props.vendorID == 0x8086) + continue; // Intel GPUs not supported + if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device)) continue; + VkPhysicalDeviceMemoryProperties memoryProperties = 
physical_device.getMemoryProperties(); + size_t heapSize = 0; for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) { VkMemoryHeap heap = memoryProperties.memoryHeaps[j]; @@ -233,32 +303,32 @@ static std::vector ggml_vk_available_devices_internal(size_t mem if (subgroup_props.subgroupSize < 32) continue; - ggml_vk_device d; - d.index = i; - d.type = dev_props.deviceType; - d.heapSize = heapSize; - d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID)); - d.subgroupSize = subgroup_props.subgroupSize; - d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment; - - if (has_maintenance4) { - d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize); - } else { - d.maxAlloc = dev_props3.maxMemoryAllocationSize; - } - std::string name(dev_props.deviceName); size_t n_idx = ++count_by_name[name]; if (n_idx > 1) { name += " (" + std::to_string(n_idx) + ")"; } - d.name = strdup(name.c_str()); - results.push_back(d); + uint64_t maxAlloc = dev_props3.maxMemoryAllocationSize; + if (has_maintenance4) { + maxAlloc = std::min(maxAlloc, dev_props4.maxBufferSize); + } + + results.emplace_back( + /* index = */ i, + /* type = */ dev_props.deviceType, + /* heapSize = */ heapSize, + /* name = */ name.c_str(), + /* vendor = */ ggml_vk_getVendorName(dev_props.vendorID), + /* subgroupSize = */ subgroup_props.subgroupSize, + /* bufferAlignment = */ dev_props.limits.minStorageBufferOffsetAlignment, + /* maxAlloc = */ maxAlloc + ); } - std::stable_sort(results.begin(), results.end(), - [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool { + // std::list::sort is guaranteed to be stable + results.sort( + [](const ggml_vk_device_cpp & lhs, const ggml_vk_device_cpp & rhs) -> bool { if (lhs.type != rhs.type) { if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true; if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false; @@ -266,7 +336,7 @@ static std::vector ggml_vk_available_devices_internal(size_t mem if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true; if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false; } - return lhs.heapSize < rhs.heapSize; + return lhs.heapSize > rhs.heapSize; // most VRAM first } ); @@ -283,52 +353,53 @@ ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count size_t nbytes = sizeof (ggml_vk_device) * (devices.size()); auto * arr = static_cast(malloc(nbytes)); - memcpy(arr, devices.data(), nbytes); + + int i = 0; + for (auto & d : devices) { arr[i++] = d; d.steal(); } + return arr; } -static void ggml_vk_filterByVendor(std::vector& devices, const std::string& targetVendor) { +static void ggml_vk_filterByVendor(std::list & devices, const std::string & targetVendor) { devices.erase( std::remove_if(devices.begin(), devices.end(), - [&targetVendor](const ggml_vk_device& device) { + [&targetVendor](const ggml_vk_device_cpp & device) { return device.vendor != targetVendor; }), devices.end() ); } -static void ggml_vk_filterByName(std::vector& devices, const std::string& targetName) { +static void ggml_vk_filterByName(std::list & devices, const std::string & targetName) { devices.erase( std::remove_if(devices.begin(), devices.end(), - [&targetName](const ggml_vk_device& device) { + [&targetName](const ggml_vk_device_cpp & device) { return device.name != targetName; }), devices.end() ); } -static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) { - if (name.empty()) +bool ggml_vk_get_device(ggml_vk_device * 
device, size_t memoryRequired, const char * name) { + if (!*name) return false; auto devices = ggml_vk_available_devices_internal(memoryRequired); - if (name == "amd" || name == "nvidia" || name == "intel") { - ggml_vk_filterByVendor(devices, name); - } else if (name != "gpu") { - ggml_vk_filterByName(devices, name); + std::string name_str(name); + if (name_str == "amd" || name_str == "nvidia" || name_str == "intel") { + ggml_vk_filterByVendor(devices, name_str); + } else if (name_str != "gpu") { + ggml_vk_filterByName(devices, name_str); } if (devices.empty()) return false; *device = devices.front(); + devices.front().steal(); return true; } -bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) { - return ggml_vk_get_device(device, memoryRequired, std::string(name)); -} - bool ggml_vk_has_vulkan() { return komputeManager()->hasVulkan(); } @@ -337,12 +408,13 @@ bool ggml_vk_has_device() { return komputeManager()->hasDevice(); } -ggml_vk_device ggml_vk_current_device() { +static ggml_vk_device_cpp ggml_vk_current_device() { if (!komputeManager()->hasDevice()) - return ggml_vk_device(); + return ggml_vk_device_cpp(); auto devices = ggml_vk_available_devices_internal(0); ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data()); + GGML_ASSERT(!devices.empty()); return devices.front(); } @@ -432,7 +504,7 @@ vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, v vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory); if (r != vk::Result::eSuccess) { std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl; - throw std::runtime_error("Error allocating vulkan memory."); + return nullptr; } return vkDeviceMemory; } @@ -454,9 +526,13 @@ static ggml_vk_memory ggml_vk_allocate(size_t size) { bool isHostVisible = false; { memory.primaryBuffer = ggml_vk_allocate_buffer(size); + vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer); vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal; memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible); + if (!memory.primaryMemory) + return {}; + komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0); if (isHostVisible) { vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data); @@ -500,6 +576,11 @@ static void ggml_vk_free_memory(ggml_vk_memory &memory) *memory.stagingMemory, (vk::Optional)nullptr); } + + delete memory.primaryMemory; + delete memory.primaryBuffer; + delete memory.stagingMemory; + delete memory.stagingBuffer; } static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft); @@ -522,31 +603,47 @@ ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & of } static -const std::shared_ptr ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) { - uint64_t originalOffset = 0; - auto * res = ggml_vk_find_tensor(t, originalOffset); +const std::shared_ptr ggml_vk_get_tensor_aligned(const struct ggml_tensor * t, uint32_t * aligned_offset) { + uint64_t original_offset = 0; + auto * res = ggml_vk_find_tensor(t, original_offset); if (!res) { - static std::shared_ptr nullTensor = nullptr; return nullTensor; } // Create a tensor whose memory will be composed of our 
buffers at the correct offset - const size_t nelements = ggml_nelements(t); size_t nbytes = ggml_nbytes(t); + size_t vulkan_offset = ggml_vk_aligned_offset(t->buffer, original_offset); + *aligned_offset = original_offset - vulkan_offset; + nbytes += *aligned_offset; + + return komputeManager()->tensor( + t->data, + ggml_nelements(t), nbytes, + kp::Tensor::TensorDataTypes::eFloat, + res->primaryMemory, res->primaryBuffer, + res->stagingMemory, res->stagingBuffer, + vulkan_offset); +} - size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset); - if (alignedOffset) { - *alignedOffset = originalOffset - vulkanOffset; - nbytes += *alignedOffset; +static +const std::shared_ptr ggml_vk_get_tensor_slice(const struct ggml_tensor * t, size_t offset, size_t nbytes) { + uint64_t tensor_offset = 0; + auto * res = ggml_vk_find_tensor(t, tensor_offset); + if (!res) { + return nullTensor; } + size_t elsz = ggml_element_size(t); + GGML_ASSERT(nbytes % elsz == 0); + + // Create a tensor whose memory will be composed of our buffers at the correct offset return komputeManager()->tensor( - t->data, - nelements, - nbytes, kp::Tensor::TensorDataTypes::eFloat, + reinterpret_cast(t->data) + offset, + nbytes / elsz, nbytes, + kp::Tensor::TensorDataTypes::eFloat, res->primaryMemory, res->primaryBuffer, res->stagingMemory, res->stagingBuffer, - vulkanOffset); + tensor_offset + offset); } static std::vector getSpirvShader(const unsigned char* rawData, size_t size) { @@ -1324,6 +1421,15 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) { } static bool ggml_vk_supports_op(const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + return true; // noop -> dst type does not matter + } + switch (op->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: @@ -1331,36 +1437,35 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { case GGML_TYPE_Q4_1: break; default: - return false; + return false; // dst type not supported } switch (op->op) { + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: + return true; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: - return ggml_is_contiguous(op->src[0]); + return ggml_nelements(op) % 4 == 0 && ggml_is_contiguous(op->src[0]); default: ; } break; - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_SCALE: case GGML_OP_SOFT_MAX: - case GGML_OP_RMS_NORM: - case GGML_OP_NORM: + float max_bias; + memcpy(&max_bias, (const float *)op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; case GGML_OP_ROPE: - return true; - case GGML_OP_DUP: - case GGML_OP_CPY: + return op->src[2] == nullptr; case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_DUP: switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -1497,13 +1602,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - const static std::shared_ptr nullTensor = nullptr; uint32_t off_src0 = 0; uint32_t off_src1 = 0; uint32_t off_dst = 0; - const std::shared_ptr& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor; - const std::shared_ptr& id_src1 = src1 ? 
ggml_vk_get_tensor(src1, &off_src1) : nullTensor; - const std::shared_ptr& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor; + std::shared_ptr id_src0 = src0 ? ggml_vk_get_tensor_aligned(src0, &off_src0) : nullTensor; + std::shared_ptr id_src1 = src1 ? ggml_vk_get_tensor_aligned(src1, &off_src1) : nullTensor; + std::shared_ptr id_dst = dst ? ggml_vk_get_tensor_aligned(dst, &off_dst) : nullTensor; switch (dst->op) { case GGML_OP_ADD: @@ -1793,30 +1897,44 @@ struct ggml_backend_kompute_buffer_type_context { }; static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) { + static int s_cur_device = -1; + auto * ctx = static_cast(buft->context); + auto * mgr = komputeManager(); - if (!ctx->device_ref) { - komputeManager()->initializeDevice( - ctx->device, {}, { - "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage", - "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info" - } + if (!ctx->device_ref && (!mgr->hasDevice() || s_cur_device != ctx->device)) { + assert(!global_device_ref); + if (mgr->hasDevice()) { + komputeManager.freeDevice(); + } + mgr->initializeDevice( + ctx->device, {}, {"VK_KHR_8bit_storage", "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"} ); + s_cur_device = ctx->device; } - assert(ggml_vk_has_device()); + assert(mgr->hasDevice()); ctx->device_ref++; + global_device_ref++; } static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) { auto * ctx = static_cast(buft->context); - assert(ctx->device_ref > 0); + assert(ctx->device_ref > 0); + assert(global_device_ref > 0); ctx->device_ref--; + global_device_ref--; if (!ctx->device_ref) { - komputeManager.destroy(); + assert(!global_device_ref); + // free device memory + komputeManager.cleanup(); + if (!s_kompute_context) { + // ggml_backend_kompute_free was previously called, we can now fully cleanup Vulkan + komputeManager.freeInstance(); + } } } @@ -1827,10 +1945,9 @@ static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t b static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) { auto * memory = (ggml_vk_memory *)buffer->context; - if (ggml_vk_has_device()) { - ggml_vk_free_memory(*memory); - } + ggml_vk_free_memory(*memory); delete memory; + ggml_backend_kompute_device_unref(buffer->buft); } static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) { @@ -1840,7 +1957,7 @@ static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_UNUSED(buffer); - const auto res = ggml_vk_get_tensor(tensor); + const auto res = ggml_vk_get_tensor_slice(tensor, offset, size); GGML_ASSERT(res); memcpy((char *)tensor->data + offset, data, size); @@ -1851,7 +1968,7 @@ static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_UNUSED(buffer); - const auto res = ggml_vk_get_tensor(tensor); + const auto res = ggml_vk_get_tensor_slice(tensor, offset, size); GGML_ASSERT(res); komputeManager()->sequence()->eval({res}); @@ -1889,6 +2006,11 @@ static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffe static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { 
ggml_backend_kompute_device_ref(buft); auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size)); + if (!ctx->primaryMemory) { + ggml_backend_kompute_device_unref(buft); + delete ctx; + return nullptr; // allocation failed + } return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size); } @@ -1912,15 +2034,20 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = { }; ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) { + // used only to prevent leaking context structs + static std::vector> buft_contexts; + static std::vector bufts = []() { std::vector vec; auto devices = ggml_vk_available_devices_internal(0); vec.reserve(devices.size()); for (const auto & dev : devices) { + auto *ctx = new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc); + buft_contexts.emplace_back(ctx); vec.push_back({ /* .iface = */ ggml_backend_kompute_buffer_type_interface, - /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc) + /* .context = */ ctx, }); } return vec; @@ -1945,6 +2072,10 @@ static void ggml_backend_kompute_free(ggml_backend_t backend) { assert(ctx == s_kompute_context); s_kompute_context = nullptr; if (ctx != nullptr) { + if (!global_device_ref) { + // there are no more device refs, we can now fully cleanup Vulkan + komputeManager.freeInstance(); + } delete ctx; } diff --git a/ggml/src/ggml-rocm b/ggml/src/ggml-rocm new file mode 120000 index 0000000000000..4b84466af7cb3 --- /dev/null +++ b/ggml/src/ggml-rocm @@ -0,0 +1 @@ +ggml-cuda \ No newline at end of file diff --git a/ggml/src/ggml-rocm.cu b/ggml/src/ggml-rocm.cu new file mode 120000 index 0000000000000..8148fc234d18e --- /dev/null +++ b/ggml/src/ggml-rocm.cu @@ -0,0 +1 @@ +ggml-cuda.cu \ No newline at end of file diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 8efe32329693e..ef67309ba55bf 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -10,15 +10,17 @@ #include #include #include -#include -#include -#include -#include -#include #include +#include #include #include +#include #include +#include +#include +#include +#include +#include #include "ggml.h" #include "ggml-backend-impl.h" @@ -1870,7 +1872,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); -void ggml_vk_instance_init() { +static void ggml_vk_instance_init() { if (vk_instance_initialized) { return; } @@ -1924,14 +1926,15 @@ void ggml_vk_instance_init() { vk_instance.instance = vk::createInstance(instance_create_info); size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); if (devices_env != nullptr) { - std::string devices(devices_env); - std::replace(devices.begin(), devices.end(), ',', ' '); + std::string dev_indices(devices_env); + std::replace(dev_indices.begin(), dev_indices.end(), ',', ' '); - std::stringstream ss(devices); + std::stringstream ss(dev_indices); size_t tmp; while (ss >> tmp) { if(tmp >= num_available_devices) { @@ -1941,15 +1944,14 @@ void ggml_vk_instance_init() { vk_instance.device_indices.push_back(tmp); } } else { - std::vector devices = 
vk_instance.instance.enumeratePhysicalDevices(); - // Make sure at least one device exists if (devices.empty()) { std::cerr << "ggml_vulkan: Error: No devices found." << std::endl; - GGML_ASSERT(false); + throw std::runtime_error("No Vulkan devices found"); } - // Default to using all dedicated GPUs + // Default to making all GPUs available + vk_instance.device_indices.reserve(devices.size()); for (size_t i = 0; i < devices.size(); i++) { vk::PhysicalDeviceProperties2 new_props; vk::PhysicalDeviceDriverProperties new_driver; @@ -1958,80 +1960,73 @@ void ggml_vk_instance_init() { new_driver.pNext = &new_id; devices[i].getProperties2(&new_props); - if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { - // Check if there are two physical devices corresponding to the same GPU - auto old_device = std::find_if( - vk_instance.device_indices.begin(), - vk_instance.device_indices.end(), - [&devices, &new_id](const size_t k){ - vk::PhysicalDeviceProperties2 old_props; - vk::PhysicalDeviceIDProperties old_id; - old_props.pNext = &old_id; - devices[k].getProperties2(&old_props); - return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); - } - ); - if (old_device == vk_instance.device_indices.end()) { - vk_instance.device_indices.push_back(i); - } else { - // There can be two physical devices corresponding to the same GPU if there are 2 different drivers - // This can cause error when splitting layers aross the devices, need to keep only 1 - VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); - + // Check if there are two physical devices corresponding to the same GPU + auto old_device = std::find_if( + vk_instance.device_indices.begin(), + vk_instance.device_indices.end(), + [&devices, &new_id](const size_t k){ vk::PhysicalDeviceProperties2 old_props; - vk::PhysicalDeviceDriverProperties old_driver; - old_props.pNext = &old_driver; - devices[*old_device].getProperties2(&old_props); - - std::map driver_priorities {}; - int old_priority = std::numeric_limits::max(); - int new_priority = std::numeric_limits::max(); - - // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id - // Smaller number -> higher priority - switch (old_props.properties.vendorID) { - case VK_VENDOR_ID_AMD: - driver_priorities[vk::DriverId::eMesaRadv] = 1; - driver_priorities[vk::DriverId::eAmdOpenSource] = 2; - driver_priorities[vk::DriverId::eAmdProprietary] = 3; - break; - case VK_VENDOR_ID_INTEL: - driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; - driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; - break; - case VK_VENDOR_ID_NVIDIA: - driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; + vk::PhysicalDeviceIDProperties old_id; + old_props.pNext = &old_id; + devices[k].getProperties2(&old_props); + return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + } + ); + if (old_device == vk_instance.device_indices.end()) { + vk_instance.device_indices.push_back(i); + } else { + // There can be two physical devices corresponding to the same GPU if there are 2 different drivers + // This can cause error when splitting layers aross the devices, need to keep only 1 + VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); + + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; + old_props.pNext = 
&old_driver; + devices[*old_device].getProperties2(&old_props); + + std::map driver_priorities {}; + int old_priority = std::numeric_limits::max(); + int new_priority = std::numeric_limits::max(); + + // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id + // Smaller number -> higher priority + switch (old_props.properties.vendorID) { + case VK_VENDOR_ID_AMD: + driver_priorities[vk::DriverId::eMesaRadv] = 1; + driver_priorities[vk::DriverId::eAmdOpenSource] = 2; + driver_priorities[vk::DriverId::eAmdProprietary] = 3; + break; + case VK_VENDOR_ID_INTEL: + driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; + driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; + break; + case VK_VENDOR_ID_NVIDIA: + driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; #if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 - driver_priorities[vk::DriverId::eMesaNvk] = 2; + driver_priorities[vk::DriverId::eMesaNvk] = 2; #endif - break; - } + break; + } - if (driver_priorities.count(old_driver.driverID)) { - old_priority = driver_priorities[old_driver.driverID]; - } - if (driver_priorities.count(new_driver.driverID)) { - new_priority = driver_priorities[new_driver.driverID]; - } + if (driver_priorities.count(old_driver.driverID)) { + old_priority = driver_priorities[old_driver.driverID]; + } + if (driver_priorities.count(new_driver.driverID)) { + new_priority = driver_priorities[new_driver.driverID]; + } - if (new_priority < old_priority) { - auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); - vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); - vk_instance.device_indices.push_back(i); + if (new_priority < old_priority) { + auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); + vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); + vk_instance.device_indices.push_back(i); - VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); - } - else { - VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); - } + VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); + } + else { + VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); } } } - - // If no dedicated GPUs found, fall back to GPU 0 - if (vk_instance.device_indices.empty()) { - vk_instance.device_indices.push_back(0); - } } std::cerr << "ggml_vulkan: Found " << vk_instance.device_indices.size() << " Vulkan devices:" << std::endl; @@ -5818,6 +5813,85 @@ GGML_CALL static void ggml_vk_get_device_description(int device, char * descript snprintf(description, description_size, "%s", props.deviceName.data()); } +static std::list ggml_vk_available_devices_internal() { + std::list results; + ggml_vk_instance_init(); + + std::vector physical_devices = vk_instance.instance.enumeratePhysicalDevices(); + + std::unordered_map count_by_name; + + for (uint32_t dev_idx = 0; dev_idx < physical_devices.size(); dev_idx++) { + const auto & physical_device = physical_devices[dev_idx]; + + vk::PhysicalDeviceProperties dev_props; + 
physical_device.getProperties(&dev_props);
+        vk::PhysicalDeviceMemoryProperties mem_props;
+        physical_device.getMemoryProperties(&mem_props);
+
+        vk::DeviceSize heapSize = 0;
+        for (uint32_t i = 0; i < mem_props.memoryHeapCount; i++) {
+            const auto & heap = mem_props.memoryHeaps[i];
+            if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+                heapSize = heap.size;
+                break;
+            }
+        }
+
+        std::string name(dev_props.deviceName.data());
+        size_t n_idx = ++count_by_name[name];
+        if (n_idx > 1) {
+            name += " (" + std::to_string(n_idx) + ")";
+        }
+
+        results.push_back({
+            /* index    = */ dev_idx,
+            /* type     = */ VkPhysicalDeviceType(dev_props.deviceType),
+            /* heapSize = */ heapSize,
+            /* name     = */ strdup(name.c_str()),
+            /* vendorID = */ dev_props.vendorID
+        });
+    }
+
+    // std::list::sort is guaranteed to be stable
+    results.sort(
+        [](const ggml_vk_device & a, const ggml_vk_device & b) -> bool {
+            if (a.type != b.type) {
+                if (a.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
+                if (b.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
+
+                if (a.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
+                if (b.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
+            }
+            return a.heapSize > b.heapSize; // descending
+        }
+    );
+
+    return results;
+}
+
+// public API returns a C-style array
+ggml_vk_device * ggml_vk_available_devices(size_t * count) {
+    auto devices = ggml_vk_available_devices_internal();
+    *count = devices.size();
+    if (devices.empty()) {
+        return nullptr;
+    }
+
+    size_t nbytes = sizeof(ggml_vk_device) * devices.size();
+    auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
+
+    int i = 0;
+    for (auto & d : devices) { arr[i++] = d; }
+
+    return arr;
+}
+
+void ggml_vk_device_destroy(ggml_vk_device * device) {
+    free(const_cast<char *>(device->name));
+}
+
 // backend interface
 
 #define UNUSED GGML_UNUSED
@@ -5869,6 +5943,11 @@ GGML_CALL static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
     return buffer->iface.get_name == ggml_backend_vk_buffer_get_name;
 }
 
+size_t ggml_backend_vk_idx(ggml_backend_t backend) {
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+    return ctx->compute_ctx->idx;
+}
+
 GGML_CALL static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
     ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
diff --git a/ggml/src/kompute b/ggml/src/kompute
index 4565194ed7c32..f592b5bca3cbc 160000
--- a/ggml/src/kompute
+++ b/ggml/src/kompute
@@ -1 +1 @@
-Subproject commit 4565194ed7c32d1d2efa32ceab4d3c6cae006306
+Subproject commit f592b5bca3cbc169feb194218a086b18d618cca4
diff --git a/include/llama.h b/include/llama.h
index c0fb53060eae4..ddff304b7ccc4 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -507,6 +507,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is using the GPU/accelerator device
+    LLAMA_API bool llama_model_using_gpu(struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -878,6 +881,12 @@ extern "C" {
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    //
+    // Model Info
+    //
+    LLAMA_API const char * llama_model_name(const struct llama_model * model);
+    LLAMA_API const char * llama_model_arch(const struct llama_model * model);
+
     //
     // Vocab
     //
diff --git a/poetry.lock b/poetry.lock
index eb6baa6c749c0..942c01e9dfb2b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "atomicwrites"
@@ -1066,13 +1066,13 @@ reference = "pytorch"
 
 [[package]]
 name = "tqdm"
-version = "4.66.2"
+version = "4.66.3"
 description = "Fast, Extensible Progress Meter"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"},
-    {file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"},
+    {file = "tqdm-4.66.3-py3-none-any.whl", hash = "sha256:4f41d54107ff9a223dca80b53efe4fb654c67efaba7f47bada3ee9d50e05bd53"},
+    {file = "tqdm-4.66.3.tar.gz", hash = "sha256:23097a41eba115ba99ecae40d06444c15d1c0c698d527a01c6c8bd1c5d0647e5"},
 ]
 
 [package.dependencies]
diff --git a/src/llama.cpp b/src/llama.cpp
index 20e85b3ebe5df..dff2913197196 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2679,6 +2679,13 @@ struct llama_model {
 
     std::vector<std::string> rpc_servers;
 
+#if defined(GGML_USE_METAL) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL) \
+    || defined(GGML_USE_CLBLAST) || defined(GGML_USE_KOMPUTE)
+    bool using_gpu = true;
+#else
+    bool using_gpu = false;
+#endif
+
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
@@ -2736,12 +2743,21 @@ struct llama_model {
     }
 };
 
+#ifdef GGML_USE_VULKAN
+static bool vulkan_backend_initialized[GGML_VK_MAX_DEVICES] = {};
+#endif
+
 struct llama_context {
     llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
     ~llama_context() {
         ggml_backend_sched_free(sched);
 
         for (ggml_backend_t backend : backends) {
+#ifdef GGML_USE_VULKAN
+            if (ggml_backend_is_vk(backend)) {
+                vulkan_backend_initialized[ggml_backend_vk_idx(backend)] = false;
+            }
+#endif
             ggml_backend_free(backend);
         }
 
@@ -5972,6 +5988,7 @@ static bool llm_load_tensors(
         model.buft_layer[i] = llama_default_buffer_type_cpu(true);
     }
 
+#ifndef GGML_USE_KOMPUTE
     if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
         // calculate the split points
         int device_count = llama_get_device_count(model);
@@ -6009,7 +6026,9 @@ static bool llm_load_tensors(
         } else {
             model.buft_output = llama_default_buffer_type_cpu(true);
         }
-    } else {
+    } else
+#endif
+    {
         ggml_backend_buffer_type_t split_buft;
         if (split_mode == LLAMA_SPLIT_MODE_ROW) {
             split_buft = llama_default_buffer_type_split(model, main_gpu, tensor_split);
@@ -6017,6 +6036,12 @@ static bool llm_load_tensors(
             // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
             split_buft = llama_default_buffer_type_offload(model, main_gpu);
         }
+#ifdef GGML_USE_KOMPUTE
+        // we can fall back to CPU buffer type in some cases
+        if (!strcmp(ggml_backend_buft_name(split_buft), "CPU")) {
+            model.using_gpu = false;
+        }
+#endif
 
         // assign the repeating layers
         for (int i = i_gpu_start; i < n_layer; ++i) {
             model.buft_layer[i] = {
@@ -7714,8 +7739,34 @@ static bool llm_load_tensors(
     return true;
 }
 
+#ifdef GGML_USE_KOMPUTE
+static const llm_arch LLM_KOMPUTE_SUPPORTED_ARCHES[] {
+    LLM_ARCH_LLAMA,
+    LLM_ARCH_FALCON,
+    LLM_ARCH_BAICHUAN,
+    LLM_ARCH_GPT2,
+    // LLM_ARCH_MPT,       -- needs GGML_OP_ALIBI
+    LLM_ARCH_STARCODER,
+    // LLM_ARCH_PERSIMMON, -- needs GGML_OP_CONCAT
+    // LLM_ARCH_REFACT,    -- needs GGML_OP_ALIBI
+    LLM_ARCH_BERT,
+    LLM_ARCH_NOMIC_BERT,
+    // LLM_ARCH_BLOOM,     -- needs GGML_OP_ALIBI
+    LLM_ARCH_STABLELM,
+    LLM_ARCH_QWEN,
+    LLM_ARCH_QWEN2,
+    LLM_ARCH_PHI2,
+    // LLM_ARCH_PLAMO,     -- unable to test
+    LLM_ARCH_CODESHELL,
+    LLM_ARCH_ORION,
+    LLM_ARCH_INTERNLM2,
+    LLM_ARCH_MINICPM,
+    LLM_ARCH_GEMMA,
+};
+#endif
+
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
+static int llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
     try {
         llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
 
@@ -7749,25 +7800,36 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         return 0;
     }
 
+    int n_gpu_layers = params.n_gpu_layers;
+
+    // NOTE: Metal and Kompute do no compute on the GPU with ngl=0, CUDA and Vulkan do
+    // TODO(cebtenzzre): What about other backends?
 #ifdef GGML_USE_KOMPUTE
-    if (params.n_gpu_layers > 0 && (
-        !(model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON)
+    auto & kparch = LLM_KOMPUTE_SUPPORTED_ARCHES;
+    if (!params.n_gpu_layers) {
+        model.using_gpu = false;
+    } else if (
+        std::find(kparch, std::end(kparch), model.arch) == std::end(kparch)
+        || model.hparams.n_expert > 0
         || !(
             model.ftype == LLAMA_FTYPE_ALL_F32 ||
             model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
-            model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
             model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
             model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
         )
-    )) {
-        // TODO(cebtenzzre): propagate this error outside of llama_load_model_from_file
+    ) {
         LLAMA_LOG_WARN("%s: disabling Kompute due to unsupported model arch or quantization\n", __func__);
-        params.n_gpu_layers = 0;
+        model.using_gpu = false;
+        n_gpu_layers = 0;
+    }
+#elif defined(GGML_USE_METAL)
+    if (!params.n_gpu_layers) {
+        model.using_gpu = false;
     }
 #endif
 
     if (!llm_load_tensors(
-        ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
+        ml, model, n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
         params.progress_callback, params.progress_callback_user_data
     )) {
         return -2;
@@ -18936,7 +18998,7 @@ int64_t llama_time_us(void) {
 
 struct llama_model * llama_load_model_from_file(
         const char * path_model,
-        struct llama_model_params   params) {
+        struct llama_model_params params) {
     ggml_time_init();
 
     llama_model * model = new llama_model;
@@ -19175,6 +19237,8 @@ struct llama_context * llama_new_context_with_model(
             return nullptr;
         }
         if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
+            GGML_ASSERT(!vulkan_backend_initialized[model->main_gpu]);
+            vulkan_backend_initialized[model->main_gpu] = true;
             ggml_backend_t backend = ggml_backend_vk_init(model->main_gpu);
             if (backend == nullptr) {
                 LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
@@ -19184,6 +19248,8 @@ struct llama_context * llama_new_context_with_model(
             ctx->backends.push_back(backend);
         } else {
             for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
+                GGML_ASSERT(!vulkan_backend_initialized[device]);
+                vulkan_backend_initialized[device] = true;
                 ggml_backend_t backend = ggml_backend_vk_init(device);
                 if (backend == nullptr) {
                     LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
@@ -19218,7 +19284,7 @@ struct llama_context * llama_new_context_with_model(
             }
         }
 #elif defined(GGML_USE_KOMPUTE)
-        if (model->n_gpu_layers > 0) {
+        if (model->using_gpu) {
             auto * backend = ggml_backend_kompute_init(model->main_gpu);
             if (backend == nullptr) {
                 LLAMA_LOG_ERROR("%s: failed to initialize Kompute backend\n", __func__);
@@ -19584,6 +19650,10 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_using_gpu(struct llama_model * model) {
+    return model->using_gpu;
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
@@ -20073,7 +20143,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
     data_ctx->write(&kv_used,     sizeof(kv_used));
     data_ctx->write(&v_trans,     sizeof(v_trans));
 
-    if (kv_buf_size) {
+    if (kv_buf_size && kv_head) {
         const size_t pre_kv_buf_size = data_ctx->get_size_written();
 
         std::vector<uint8_t> tmp_buf;
@@ -20238,10 +20308,10 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
 
     llama_kv_cache_clear(ctx);
 
-    if (kv_buf_size) {
-        const size_t pre_kv_buf_size = inp - src;
+    GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
 
-        GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
+    if (kv_buf_size && kv_head) {
+        const size_t pre_kv_buf_size = inp - src;
 
         for (int il = 0; il < (int) n_layer; ++il) {
             const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
@@ -21096,6 +21166,14 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
     return it->second.data();
 }
 
+const char * llama_model_name(const struct llama_model * model) {
+    return model->name.c_str();
+}
+
+const char * llama_model_arch(const struct llama_model * model) {
+    return LLM_ARCH_NAMES.at(model->arch);
+}
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
     GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return model->vocab.id_to_token[token].text.c_str();
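
The sketch below is not part of the patch. It is a minimal illustration of how a caller might consume the device-enumeration API declared in ggml-vulkan.h above (the same ownership rules apply to ggml_cuda_available_devices / ggml_cuda_device_destroy): the returned array comes from malloc(), each entry owns a strdup'd name, and the results are pre-sorted (discrete GPUs first, then by descending heap size). Output formatting and error handling here are illustrative only.

// enumerate_devices.cpp -- example consumer, not part of the patch
#include <cstdio>
#include <cstdlib>

#include "ggml-vulkan.h"

int main() {
    size_t count = 0;
    // sorted: discrete GPUs first, then by descending device-local heap size
    ggml_vk_device * devices = ggml_vk_available_devices(&count);

    for (size_t i = 0; i < count; i++) {
        printf("device %u: %s (%llu MiB)\n",
               devices[i].index, devices[i].name,
               (unsigned long long) (devices[i].heapSize / (1024 * 1024)));
        ggml_vk_device_destroy(&devices[i]); // frees only the strdup'd name
    }
    free(devices); // the array itself was allocated with malloc()
    return 0;
}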
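Likewise, a hedged sketch of the new llama.h helpers introduced by this patch (llama_model_name, llama_model_arch, llama_model_using_gpu). The model path is a placeholder; llama_model_using_gpu is the piece that reports whether the Kompute path silently fell back to CPU during load.

// model_info.cpp -- example consumer, not part of the patch
#include <cstdio>

#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99; // request full offload; may be disabled for unsupported arch/quant

    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    printf("name: %s, arch: %s, using GPU: %s\n",
           llama_model_name(model),
           llama_model_arch(model),
           llama_model_using_gpu(model) ? "yes" : "no");

    llama_free_model(model);
    return 0;
}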