Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 93 additions & 12 deletions mlx/backend/no_gpu/allocator.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2026 Apple Inc.

#include <algorithm>
#include <mutex>

#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/memory.h"

#ifdef __APPLE__
Expand All @@ -20,8 +21,14 @@ namespace mlx::core {

namespace allocator {

// Bookkeeping record for one cached CPU allocation. Records themselves are
// pooled through an intrusive freelist (threaded via `next_free`) so that
// caching/uncaching buffers does not repeatedly hit the heap.
struct CpuCachedBuffer {
  void* ptr; // start of the underlying std::malloc'd block (header included)
  size_t size; // usable size recorded for this buffer
  CpuCachedBuffer* next_free; // intrusive freelist for object pooling
};

class CommonAllocator : public Allocator {
/** A general CPU allocator. */
/** A general CPU allocator with buffer caching. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
Expand All @@ -46,14 +53,65 @@ class CommonAllocator : public Allocator {
return limit;
}

size_t get_cache_memory() const {
return buffer_cache_.cache_size();
}
size_t set_cache_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(cache_limit_, limit);
if (buffer_cache_.cache_size() > cache_limit_) {
buffer_cache_.release_cached_buffers(
buffer_cache_.cache_size() - cache_limit_);
}
return limit;
}
void clear_cache() {
std::unique_lock lk(mutex_);
buffer_cache_.clear();
}

private:
CpuCachedBuffer* pool_head_ = nullptr;

CpuCachedBuffer* alloc_ccb(void* ptr, size_t sz) {
CpuCachedBuffer* ccb;
if (pool_head_) {
ccb = pool_head_;
pool_head_ = ccb->next_free;
ccb->ptr = ptr;
ccb->size = sz;
} else {
ccb = new CpuCachedBuffer{ptr, sz, nullptr};
}
return ccb;
}

void free_ccb(CpuCachedBuffer* ccb) {
ccb->next_free = pool_head_;
pool_head_ = ccb;
}

size_t memory_limit_;
size_t cache_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::mutex mutex_;
CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
mutable std::mutex mutex_;
mutable BufferCache<CpuCachedBuffer> buffer_cache_;
CommonAllocator()
: memory_limit_(0.8 * get_memory_size()),
cache_limit_(32UL << 20), // 32 MB default cache limit
buffer_cache_(
/* page_size = */
4096,
/* get_size = */
[](CpuCachedBuffer* b) { return b->size; },
/* free = */
[this](CpuCachedBuffer* b) {
std::free(b->ptr);
free_ccb(b);
}) {
if (memory_limit_ == 0) {
memory_limit_ = 1UL << 33;
memory_limit_ = 1ULL << 33;
}
};

Expand All @@ -77,21 +135,43 @@ void* Buffer::raw_ptr() {
}

// Allocate `size` usable bytes, preferring a recycled buffer from the cache.
// Cached buffers keep their original (possibly larger) capacity; accounting
// uses that capacity so it matches what size()/free() will later observe.
Buffer CommonAllocator::malloc(size_t size) {
  // Fast path: satisfy the request from the buffer cache.
  std::unique_lock lk(mutex_);
  CpuCachedBuffer* cached = buffer_cache_.reuse_from_cache(size);
  if (cached) {
    void* ptr = cached->ptr;
    size_t alloc_size = cached->size;
    free_ccb(cached); // recycle the bookkeeping record
    active_memory_ += alloc_size;
    peak_memory_ = std::max(active_memory_, peak_memory_);
    return Buffer{ptr};
  }
  // Cache miss: drop the lock while calling into the OS allocator so other
  // threads are not serialized behind a potentially slow std::malloc.
  lk.unlock();

  void* ptr = std::malloc(size + sizeof(size_t));
  if (ptr != nullptr) {
    // Stash the requested size in a header so size(buffer) can read it.
    *static_cast<size_t*>(ptr) = size;
  }
  lk.lock();
  if (ptr != nullptr) {
    // Only count successful allocations: a failed malloc never reaches
    // free(), so counting it would permanently inflate active_memory_.
    active_memory_ += size;
    peak_memory_ = std::max(active_memory_, peak_memory_);
  }
  return Buffer{ptr};
}

// Release a buffer: park it in the cache for reuse when it fits under the
// cache limit, otherwise return it to the OS. The buffer must not be freed
// before the cache decision — recycling an already-freed pointer would be a
// use-after-free.
void CommonAllocator::free(Buffer buffer) {
  auto sz = size(buffer);
  std::unique_lock lk(mutex_);
  active_memory_ -= sz;

  if (sz > 0 && buffer_cache_.cache_size() + sz <= cache_limit_) {
    // Add to cache for reuse; ownership of the pointer moves to the cache.
    auto* entry = alloc_ccb(buffer.ptr(), sz);
    buffer_cache_.recycle_to_cache(entry);
  } else {
    // Cache full (or zero-sized buffer): release outside the lock so a
    // slow std::free cannot block other allocator calls.
    lk.unlock();
    std::free(buffer.ptr());
  }
}

size_t CommonAllocator::size(Buffer buffer) const {
Expand Down Expand Up @@ -119,16 +199,17 @@ size_t get_memory_limit() {
return allocator::common_allocator().get_memory_limit();
}

// Bytes currently held in the CPU allocator's buffer cache.
size_t get_cache_memory() {
  return allocator::common_allocator().get_cache_memory();
}
size_t set_cache_limit(size_t) {
return 0;
size_t set_cache_limit(size_t limit) {
return allocator::common_allocator().set_cache_limit(limit);
}
// No-op: wired (page-locked) memory is not a concept the plain CPU
// allocator manages, so the requested limit is ignored and 0 is returned.
size_t set_wired_limit(size_t) {
  return 0;
}
// Release every buffer held in the CPU allocator's cache back to the OS.
void clear_cache() {
  allocator::common_allocator().clear_cache();
}

} // namespace mlx::core
19 changes: 19 additions & 0 deletions tests/allocator_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "doctest/doctest.h"

#include "mlx/allocator.h"
#include "mlx/memory.h"

using namespace mlx::core;

Expand Down Expand Up @@ -39,3 +40,21 @@ TEST_CASE("test large allocations") {
allocator::free(buffer);
}
}

TEST_CASE("test cached allocation keeps capacity") {
  // Bound the cache at 1 MB and start from an empty cache.
  auto prev_limit = set_cache_limit(1 << 20);
  clear_cache();

  // Freeing a buffer should deposit its capacity into the cache.
  auto big_buf = allocator::malloc(8192);
  allocator::free(big_buf);
  auto bytes_cached = get_cache_memory();
  CHECK_GE(bytes_cached, 8192);

  // A smaller request is served from the cache and keeps the larger
  // capacity; releasing it returns that capacity to the cache.
  auto small_buf = allocator::malloc(6000);
  CHECK_GE(allocator::allocator().size(small_buf), bytes_cached);
  allocator::free(small_buf);
  CHECK_GE(get_cache_memory(), bytes_cached);

  // Restore global allocator state so other tests see the defaults.
  clear_cache();
  set_cache_limit(prev_limit);
}
Loading