diff --git a/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp b/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp index 0fe7275f..a27aea9a 100644 --- a/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp +++ b/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp @@ -9,6 +9,7 @@ **/ #include "shared_mem_buffer.h" #include +#include SharedMemBuffer::SharedMemBuffer() { buffer_ = nullptr; @@ -17,7 +18,7 @@ SharedMemBuffer::SharedMemBuffer() { SharedMemBuffer::~SharedMemBuffer() { if (buffer_) { - free(buffer_); + ::operator delete(buffer_, std::align_val_t(64)); } } @@ -28,9 +29,9 @@ void SharedMemBuffer::alloc(void* object, std::vector size_) { if (buffer_) { - free(buffer_); + ::operator delete(buffer_, std::align_val_t(64)); } - buffer_ = std::aligned_alloc(64, size); + buffer_ = ::operator new(size, std::align_val_t(64)); size_ = size; for (auto& obj_requests : hist_requests_) { diff --git a/setup.py b/setup.py index c91d9dc2..14eefdaf 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,10 @@ "pytorch-triton-xpu==3.3.0" ] else: - triton_dep = ["triton>=3.2"] + triton_dep = [ + "triton >= 3.2; sys_platform != 'win32' and sys_platform != 'Windows'", + "triton-windows >= 3.2; sys_platform == 'win32' or sys_platform == 'Windows'" + ] with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"