@@ -78,7 +78,8 @@ class reservation_aware_resource_adaptor_impl {
* @brief Reservation state
*/
struct stream_ordered_tracker_state {
- std::unique_ptr<device_reserved_arena>
Contributor:
Shouldn't we just be managing this by keeping the stream_ordered_tracker_state alive, which holds the arena?
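
A minimal sketch of that alternative, with hypothetical names and generic types standing in for the PR's own (none of this is from the diff): retire the whole state instead of destroying it, and only drain retired states once the stream's work is known to be complete.

```cpp
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

// Sketch only: StreamKey/State stand in for cudaStream_t and
// stream_ordered_tracker_state; all names here are illustrative.
template <typename StreamKey, typename State>
struct deferred_release_registry {
  std::mutex mutex;
  std::unordered_map<StreamKey, std::unique_ptr<State>> live;
  std::vector<std::unique_ptr<State>> retired;  // parked states keep their arenas alive

  void retire(StreamKey key)
  {
    std::lock_guard lock{mutex};
    if (auto it = live.find(key); it != live.end()) {
      retired.push_back(std::move(it->second));  // keep the state (and its arena) alive
      live.erase(it);
    }
  }

  // Call only after synchronizing the stream, so nothing still uses the arenas.
  void drain()
  {
    std::lock_guard lock{mutex};
    retired.clear();
  }
};
```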

+ // shared_ptr so the alloc-origin map can hold a weak_ptr that outlives reset_tracker_state.
+ std::shared_ptr<device_reserved_arena>
memory_reservation; /// Stream memory reservation (may be null)
std::unique_ptr<reservation_limit_policy>
reservation_policy; /// Reservation policy for this stream
@@ -87,7 +88,7 @@ class reservation_aware_resource_adaptor_impl {
friend class reservation_aware_resource_adaptor_impl;

explicit stream_ordered_tracker_state(
- std::unique_ptr<device_reserved_arena> arena,
+ std::shared_ptr<device_reserved_arena> arena,
std::unique_ptr<reservation_limit_policy> reservation_policy,
std::unique_ptr<oom_handling_policy> oom_policy);

@@ -108,7 +109,7 @@ class reservation_aware_resource_adaptor_impl {
virtual void reset_tracker_state(rmm::cuda_stream_view stream) = 0;

virtual void assign_reservation_to_tracker(rmm::cuda_stream_view stream,
- std::unique_ptr<device_reserved_arena> reservation,
+ std::shared_ptr<device_reserved_arena> reservation,
std::unique_ptr<reservation_limit_policy> policy,
std::unique_ptr<oom_handling_policy> oom_policy) = 0;

src/memory/reservation_aware_resource_adaptor.cpp (72 additions, 20 deletions)
@@ -26,12 +26,15 @@

#include <cuda_runtime_api.h>

#include <array>
#include <atomic>
#include <cstdint>
#include <exception>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

namespace cucascade {
namespace memory {
@@ -42,6 +45,31 @@ using device_reserved_arena = impl_type::device_reserved_arena;

+ namespace {

+ // Origin reservation per allocation; weak_ptr survives reset_tracker_state.
Contributor:
Doesn't this just mean that we are calling reset_tracker_state before the work scheduled on a stream and thread is completed? It feels like we are missing some kind of synchronization before the reset is called. I don't think the solution is to fix it by trying to hold onto it longer.
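
A minimal sketch of that missing synchronization, under the assumption that draining the stream first is sufficient (the wrapper name and placement are hypothetical, not part of this PR):

```cpp
#include <rmm/cuda_stream_view.hpp>

// Hypothetical wrapper: wait for the stream's queued work, then tear down the
// tracker state, so no in-flight work can still touch the arena it owns.
void reset_tracker_state_synchronized(impl_type::allocation_tracker_iface& tracker,
                                      rmm::cuda_stream_view stream)
{
  stream.synchronize();                 // blocks until queued work on the stream finishes
  tracker.reset_tracker_state(stream);  // now safe to drop the state and its arena
}
```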

+ struct alloc_origin_info {
+   std::weak_ptr<device_reserved_arena> alloc_arena_weak;
+ };
+
+ // Sharded by ptr hash to avoid serializing every allocate/deallocate on a single mutex.
+ // 16 shards keep contention below ~10% at 16 concurrent ops while staying lightweight.
+ constexpr std::size_t kAllocOriginShards = 16;
+
+ struct alloc_origin_shard {
+   std::mutex mutex;
+   std::unordered_map<void*, alloc_origin_info> map;
+ };
+
+ inline alloc_origin_shard& alloc_origin_shard_for(void* ptr)
+ {
+   static std::array<alloc_origin_shard, kAllocOriginShards> shards;
+   // Drop the low 6 bits (alignment noise) before mixing into the shard index.
+   return shards[(reinterpret_cast<std::uintptr_t>(ptr) >> 6) % kAllocOriginShards];
+ }

+ } // namespace
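
As a worked illustration of the shard-index math above (a standalone sketch with made-up pointer values; only the shift and modulo mirror the diff):

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
  constexpr std::size_t kShards = 16;  // mirrors kAllocOriginShards
  const std::uintptr_t ptrs[] = {0x1240, 0x1280, 0x12C0};  // 64-byte-aligned values
  for (std::uintptr_t p : ptrs) {
    // Drop the low 6 alignment bits, then take the result modulo the shard count.
    std::printf("ptr 0x%04llx -> shard %zu\n",
                static_cast<unsigned long long>(p),
                static_cast<std::size_t>((p >> 6) % kShards));
  }
  // Prints shards 9, 10, 11: adjacent 64-byte blocks land on distinct shards.
}
```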

namespace {

struct stream_ordered_allocation_tracker : public impl_type::allocation_tracker_iface {
mutable std::mutex mutex;
std::unordered_map<cudaStream_t, std::unique_ptr<stream_ordered_tracker_state>> stream_stats_map;
@@ -50,14 +78,18 @@ struct stream_ordered_allocation_tracker : public impl_type::allocation_tracker_iface {

void reset_tracker_state(rmm::cuda_stream_view stream) override
{
- std::lock_guard lock(mutex);
- auto it = stream_stats_map.find(stream.value());
- if (it == stream_stats_map.end()) { return; }
- stream_stats_map.erase(stream.value());
+ std::unique_ptr<stream_ordered_tracker_state> released_state;
+ {
+   std::lock_guard lock(mutex);
+   auto it = stream_stats_map.find(stream.value());
+   if (it == stream_stats_map.end()) { return; }
+   released_state = std::move(it->second);
+   stream_stats_map.erase(it);
+ }
}
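
The shape of this refactor is the usual destroy-outside-the-lock idiom: the moved-out state (and the arena it owns) is destroyed only after the mutex is released. A generic sketch of the idiom (illustrative, not from the diff):

```cpp
#include <memory>
#include <mutex>

template <typename T>
class guarded_slot {
  std::mutex mutex_;
  std::unique_ptr<T> value_;

 public:
  void reset()
  {
    std::unique_ptr<T> doomed;  // destroyed after the lock is released
    {
      std::lock_guard lock{mutex_};
      doomed = std::move(value_);
    }
    // ~T() runs here, outside the mutex, so an expensive or re-entrant
    // destructor cannot deadlock against the slot's own lock.
  }
};
```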

void assign_reservation_to_tracker(rmm::cuda_stream_view stream,
- std::unique_ptr<device_reserved_arena> arena,
+ std::shared_ptr<device_reserved_arena> arena,
std::unique_ptr<reservation_limit_policy> policy,
std::unique_ptr<oom_handling_policy> oom_policy) override
{
@@ -67,8 +99,9 @@ struct stream_ordered_allocation_tracker : public impl_type::allocation_tracker_iface {
throw rmm::logic_error("Stream already has reservation state set");
}

- stream_stats_map[stream.value()] = std::make_unique<stream_ordered_tracker_state>(
+ auto state = std::make_unique<stream_ordered_tracker_state>(
    std::move(arena), std::move(policy), std::move(oom_policy));
+ stream_stats_map[stream.value()] = std::move(state);
}

stream_ordered_tracker_state* get_tracker_state(rmm::cuda_stream_view stream) override
@@ -95,11 +128,11 @@ struct ptds_allocation_tracker : public impl_type::allocation_tracker_iface {

void reset_tracker_state([[maybe_unused]] rmm::cuda_stream_view stream) override
{
- if (thread_reservation_state) { thread_reservation_state.reset(); }
+ thread_reservation_state.reset();
}

void assign_reservation_to_tracker([[maybe_unused]] rmm::cuda_stream_view stream,
- std::unique_ptr<device_reserved_arena> arena,
+ std::shared_ptr<device_reserved_arena> arena,
std::unique_ptr<reservation_limit_policy> policy,
std::unique_ptr<oom_handling_policy> oom_policy) override
{
@@ -127,7 +160,7 @@ struct ptds_allocation_tracker : public impl_type::allocation_tracker_iface {
} // namespace

stream_ordered_tracker_state::stream_ordered_tracker_state(
- std::unique_ptr<device_reserved_arena> arena,
+ std::shared_ptr<device_reserved_arena> arena,
std::unique_ptr<reservation_limit_policy> res_policy,
std::unique_ptr<oom_handling_policy> oom_policy)
: memory_reservation(std::move(arena)),
@@ -305,7 +338,7 @@ bool impl_type::attach_reservation_to_tracker(

_allocation_tracker->assign_reservation_to_tracker(
stream,
- std::unique_ptr<device_reserved_arena>(
+ std::shared_ptr<device_reserved_arena>(
dynamic_cast<device_reserved_arena*>(reserved_bytes->_arena.release())),
std::move(stream_reservation_policy),
std::move(stream_oom_policy));
@@ -363,11 +396,14 @@ void* impl_type::allocate(cuda::stream_ref stream,
[[maybe_unused]] std::size_t alignment)
{
auto* reservation_state = _allocation_tracker->get_tracker_state(stream);
+ void* ptr = (reservation_state != nullptr) ? do_allocate_managed(bytes, reservation_state, stream)
+                                            : do_allocate_managed(bytes, stream);
if (reservation_state != nullptr) {
-   return do_allocate_managed(bytes, reservation_state, stream);
- } else {
-   return do_allocate_managed(bytes, stream);
+   auto& shard = alloc_origin_shard_for(ptr);
+   std::lock_guard<std::mutex> lock{shard.mutex};
+   shard.map[ptr] = {reservation_state->memory_reservation};
}
+ return ptr;
}
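
Distilled to its essentials, the record-then-lock lifecycle this allocate path sets up looks like the following standalone sketch (generic types; `int` stands in for device_reserved_arena):

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>

int main()
{
  std::unordered_map<void*, std::weak_ptr<int>> origin_map;
  auto arena = std::make_shared<int>(0);  // stands in for the tracker's shared_ptr

  int allocation{};                 // stands in for a returned device pointer
  origin_map[&allocation] = arena;  // allocate(): record the owning arena

  arena.reset();                    // reset_tracker_state(): last strong ref gone

  // deallocate(): the weak_ptr no longer locks, so the adaptor falls back to
  // counting the full tracking_bytes as reclaimed from upstream.
  assert(origin_map[&allocation].lock() == nullptr);
}
```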

void* impl_type::do_allocate_managed(std::size_t bytes, rmm::cuda_stream_view stream)
@@ -443,22 +479,38 @@ void impl_type::deallocate(cuda::stream_ref stream,
std::size_t bytes,
[[maybe_unused]] std::size_t alignment) noexcept
{
+ alloc_origin_info origin{};
+ bool have_origin = false;
+ {
+   auto& shard = alloc_origin_shard_for(ptr);
+   std::lock_guard<std::mutex> lock{shard.mutex};
+   auto it = shard.map.find(ptr);
+   if (it != shard.map.end()) {
+     origin = std::move(it->second);
+     have_origin = true;
+     shard.map.erase(it);
+   }
+ }
auto tracking_bytes = rmm::align_up(bytes, rmm::CUDA_ALLOCATION_ALIGNMENT);
auto upstream_reclaimed_bytes = tracking_bytes;
- auto* reservation_state = _allocation_tracker->get_tracker_state(stream);
- if (reservation_state != nullptr) {
-   auto* reservation = reservation_state->memory_reservation.get();
-   auto reservation_size = static_cast<int64_t>(reservation->size());

+ std::shared_ptr<device_reserved_arena> origin_arena_locked;
+ if (have_origin) { origin_arena_locked = origin.alloc_arena_weak.lock(); }

+ if (origin_arena_locked) {
+   auto* origin_arena = origin_arena_locked.get();
+   auto reservation_size = static_cast<int64_t>(origin_arena->size());
int64_t post_deallocation_size =
-     reservation->allocated_bytes.sub(static_cast<int64_t>(tracking_bytes));
+     origin_arena->allocated_bytes.sub(static_cast<int64_t>(tracking_bytes));
int64_t pre_deallocation_size = post_deallocation_size + static_cast<int64_t>(tracking_bytes);
if (pre_deallocation_size <= reservation_size) {
-     // if it was made using the reserved space
+     // entirely within reserved arena
upstream_reclaimed_bytes = 0;
} else if (post_deallocation_size < reservation_size) {
-     // if it was partially made using the reserved space
+     // partially over-reservation
upstream_reclaimed_bytes = static_cast<std::size_t>(pre_deallocation_size - reservation_size);
}
+   // else: entirely over-reservation — upstream_reclaimed_bytes stays at tracking_bytes
}
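
To make the three branches concrete, a worked example with made-up numbers (not from the PR):

```cpp
// The origin arena reserves 1024 bytes; 1536 bytes were allocated against it,
// so the last 512 bytes spilled past the reservation to upstream.
// Now a 768-byte block is freed:
//   pre_deallocation_size  = 1536  (> 1024, so not entirely within the arena)
//   post_deallocation_size =  768  (< 1024, so not entirely over-reservation)
//   upstream_reclaimed_bytes = 1536 - 1024 = 512  (only the over-reserved part)
// The remaining 256 bytes of the block return to the arena and reclaim
// nothing from upstream.
```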
// Suppress false-positive null-dereference warnings from CCCL library code
#pragma GCC diagnostic push