-
Notifications
You must be signed in to change notification settings - Fork 19
Track allocation origin to fix cross-thread/stream attribution #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,12 +26,15 @@ | |
|
|
||
| #include <cuda_runtime_api.h> | ||
|
|
||
| #include <array> | ||
| #include <atomic> | ||
| #include <cstdint> | ||
| #include <exception> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <stdexcept> | ||
| #include <string> | ||
| #include <unordered_map> | ||
|
|
||
| namespace cucascade { | ||
| namespace memory { | ||
|
|
@@ -42,6 +45,31 @@ using device_reserved_arena = impl_type::device_reserved_arena; | |
|
|
||
| namespace { | ||
|
|
||
| // Origin reservation per allocation; weak_ptr survives reset_tracker_state. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. doesn't this just mean that we are calling reset_tracker_state before the work scheduled on a stream and thread is completed? It feels like we are missing some kind of synchronization before the reset is called. I don't think the solution is to fix it by trying to hold onto it longer. |
||
| struct alloc_origin_info { | ||
| std::weak_ptr<device_reserved_arena> alloc_arena_weak; | ||
| }; | ||
|
|
||
| // Sharded by ptr hash to avoid serializing every allocate/deallocate on a single mutex. | ||
| // 16 shards keeps contention below ~10% at 16 concurrent ops while staying lightweight. | ||
| constexpr std::size_t kAllocOriginShards = 16; | ||
|
|
||
| struct alloc_origin_shard { | ||
| std::mutex mutex; | ||
| std::unordered_map<void*, alloc_origin_info> map; | ||
| }; | ||
|
|
||
| inline alloc_origin_shard& alloc_origin_shard_for(void* ptr) | ||
| { | ||
| static std::array<alloc_origin_shard, kAllocOriginShards> shards; | ||
| // Drop the low 6 bits (alignment noise) before mixing into shard index. | ||
| return shards[(reinterpret_cast<std::uintptr_t>(ptr) >> 6) % kAllocOriginShards]; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| namespace { | ||
|
|
||
| struct stream_ordered_allocation_tracker : public impl_type::allocation_tracker_iface { | ||
| mutable std::mutex mutex; | ||
| std::unordered_map<cudaStream_t, std::unique_ptr<stream_ordered_tracker_state>> stream_stats_map; | ||
|
|
@@ -50,14 +78,18 @@ struct stream_ordered_allocation_tracker : public impl_type::allocation_tracker_ | |
|
|
||
| void reset_tracker_state(rmm::cuda_stream_view stream) override | ||
| { | ||
| std::lock_guard lock(mutex); | ||
| auto it = stream_stats_map.find(stream.value()); | ||
| if (it == stream_stats_map.end()) { return; } | ||
| stream_stats_map.erase(stream.value()); | ||
| std::unique_ptr<stream_ordered_tracker_state> released_state; | ||
| { | ||
| std::lock_guard lock(mutex); | ||
| auto it = stream_stats_map.find(stream.value()); | ||
| if (it == stream_stats_map.end()) { return; } | ||
| released_state = std::move(it->second); | ||
| stream_stats_map.erase(it); | ||
| } | ||
| } | ||
|
|
||
| void assign_reservation_to_tracker(rmm::cuda_stream_view stream, | ||
| std::unique_ptr<device_reserved_arena> arena, | ||
| std::shared_ptr<device_reserved_arena> arena, | ||
| std::unique_ptr<reservation_limit_policy> policy, | ||
| std::unique_ptr<oom_handling_policy> oom_policy) override | ||
| { | ||
|
|
@@ -67,8 +99,9 @@ struct stream_ordered_allocation_tracker : public impl_type::allocation_tracker_ | |
| throw rmm::logic_error("Stream already has reservation state set"); | ||
| } | ||
|
|
||
| stream_stats_map[stream.value()] = std::make_unique<stream_ordered_tracker_state>( | ||
| auto state = std::make_unique<stream_ordered_tracker_state>( | ||
| std::move(arena), std::move(policy), std::move(oom_policy)); | ||
| stream_stats_map[stream.value()] = std::move(state); | ||
| } | ||
|
|
||
| stream_ordered_tracker_state* get_tracker_state(rmm::cuda_stream_view stream) override | ||
|
|
@@ -95,11 +128,11 @@ struct ptds_allocation_tracker : public impl_type::allocation_tracker_iface { | |
|
|
||
| void reset_tracker_state([[maybe_unused]] rmm::cuda_stream_view stream) override | ||
| { | ||
| if (thread_reservation_state) { thread_reservation_state.reset(); } | ||
| thread_reservation_state.reset(); | ||
| } | ||
|
|
||
| void assign_reservation_to_tracker([[maybe_unused]] rmm::cuda_stream_view stream, | ||
| std::unique_ptr<device_reserved_arena> arena, | ||
| std::shared_ptr<device_reserved_arena> arena, | ||
| std::unique_ptr<reservation_limit_policy> policy, | ||
| std::unique_ptr<oom_handling_policy> oom_policy) override | ||
| { | ||
|
|
@@ -127,7 +160,7 @@ struct ptds_allocation_tracker : public impl_type::allocation_tracker_iface { | |
| } // namespace | ||
|
|
||
| stream_ordered_tracker_state::stream_ordered_tracker_state( | ||
| std::unique_ptr<device_reserved_arena> arena, | ||
| std::shared_ptr<device_reserved_arena> arena, | ||
| std::unique_ptr<reservation_limit_policy> res_policy, | ||
| std::unique_ptr<oom_handling_policy> oom_policy) | ||
| : memory_reservation(std::move(arena)), | ||
|
|
@@ -305,7 +338,7 @@ bool impl_type::attach_reservation_to_tracker( | |
|
|
||
| _allocation_tracker->assign_reservation_to_tracker( | ||
| stream, | ||
| std::unique_ptr<device_reserved_arena>( | ||
| std::shared_ptr<device_reserved_arena>( | ||
| dynamic_cast<device_reserved_arena*>(reserved_bytes->_arena.release())), | ||
| std::move(stream_reservation_policy), | ||
| std::move(stream_oom_policy)); | ||
|
|
@@ -363,11 +396,14 @@ void* impl_type::allocate(cuda::stream_ref stream, | |
| [[maybe_unused]] std::size_t alignment) | ||
| { | ||
| auto* reservation_state = _allocation_tracker->get_tracker_state(stream); | ||
| void* ptr = (reservation_state != nullptr) ? do_allocate_managed(bytes, reservation_state, stream) | ||
| : do_allocate_managed(bytes, stream); | ||
| if (reservation_state != nullptr) { | ||
| return do_allocate_managed(bytes, reservation_state, stream); | ||
| } else { | ||
| return do_allocate_managed(bytes, stream); | ||
| auto& shard = alloc_origin_shard_for(ptr); | ||
| std::lock_guard<std::mutex> lock{shard.mutex}; | ||
| shard.map[ptr] = {reservation_state->memory_reservation}; | ||
| } | ||
| return ptr; | ||
| } | ||
|
|
||
| void* impl_type::do_allocate_managed(std::size_t bytes, rmm::cuda_stream_view stream) | ||
|
|
@@ -443,22 +479,38 @@ void impl_type::deallocate(cuda::stream_ref stream, | |
| std::size_t bytes, | ||
| [[maybe_unused]] std::size_t alignment) noexcept | ||
| { | ||
| alloc_origin_info origin{}; | ||
| bool have_origin = false; | ||
| { | ||
| auto& shard = alloc_origin_shard_for(ptr); | ||
| std::lock_guard<std::mutex> lock{shard.mutex}; | ||
| auto it = shard.map.find(ptr); | ||
| if (it != shard.map.end()) { | ||
| origin = std::move(it->second); | ||
| have_origin = true; | ||
| shard.map.erase(it); | ||
| } | ||
| } | ||
| auto tracking_bytes = rmm::align_up(bytes, rmm::CUDA_ALLOCATION_ALIGNMENT); | ||
| auto upstream_reclaimed_bytes = tracking_bytes; | ||
| auto* reservation_state = _allocation_tracker->get_tracker_state(stream); | ||
| if (reservation_state != nullptr) { | ||
| auto* reservation = reservation_state->memory_reservation.get(); | ||
| auto reservation_size = static_cast<int64_t>(reservation->size()); | ||
|
|
||
| std::shared_ptr<device_reserved_arena> origin_arena_locked; | ||
| if (have_origin) { origin_arena_locked = origin.alloc_arena_weak.lock(); } | ||
|
|
||
| if (origin_arena_locked) { | ||
| auto* origin_arena = origin_arena_locked.get(); | ||
| auto reservation_size = static_cast<int64_t>(origin_arena->size()); | ||
| int64_t post_deallocation_size = | ||
| reservation->allocated_bytes.sub(static_cast<int64_t>(tracking_bytes)); | ||
| origin_arena->allocated_bytes.sub(static_cast<int64_t>(tracking_bytes)); | ||
| int64_t pre_deallocation_size = post_deallocation_size + static_cast<int64_t>(tracking_bytes); | ||
| if (pre_deallocation_size <= reservation_size) { | ||
| // if it was made using the reserved space | ||
| // entirely within reserved arena | ||
| upstream_reclaimed_bytes = 0; | ||
| } else if (post_deallocation_size < reservation_size) { | ||
| // if it was partially made using the reserved space | ||
| // partially over-reservation | ||
| upstream_reclaimed_bytes = static_cast<std::size_t>(pre_deallocation_size - reservation_size); | ||
| } | ||
| // else: entirely over-reservation — upstream_reclaimed_bytes stays at tracking_bytes | ||
| } | ||
| // Suppress false-positive null-dereference warnings from CCCL library code | ||
| #pragma GCC diagnostic push | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
shouldn't we just be managing this by keeping the stream_ordered_tracker_state alive which holds the arena?