Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit c27786f

Browse files
committed
allow arrays to be GC'ed, using weak_ptr when registering them
1 parent b63582e commit c27786f

File tree

6 files changed

+23
-5
lines changed

6 files changed

+23
-5
lines changed

src/MPIMediator.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515

1616
using OutputAdapter = bitsery::OutputBufferAdapter<Buffer>;
1717
using InputAdapter = bitsery::InputBufferAdapter<Buffer>;
18-
using array_keeper_type = std::unordered_map<uint64_t, tensor_i::ptr_type>;
18+
using array_keeper_type = std::unordered_map<uint64_t, tensor_i::ptr_type::weak_type>;
1919
using locker = std::lock_guard<std::mutex>;
2020

2121
static array_keeper_type s_ak;
22-
static uint64_t s_last_id = 0;
22+
static uint64_t s_last_id = Mediator::LOCAL_ONLY;
2323
constexpr static int PULL_TAG = 4711;
2424
constexpr static int PUSH_TAG = 4712;
2525
static std::mutex ak_mutex;
@@ -53,6 +53,11 @@ uint64_t MPIMediator::register_array(tensor_i::ptr_type ary)
5353
return s_last_id;
5454
}
5555

56+
uint64_t MPIMediator::unregister_array(uint64_t id)
57+
{
58+
s_ak.erase(id);
59+
}
60+
5661
void MPIMediator::pull(rank_type from, const tensor_i & ary, const NDSlice & slice, void * rbuff)
5762
{
5863
MPI_Comm comm = MPI_COMM_WORLD;
@@ -123,8 +128,9 @@ void MPIMediator::listen()
123128
if(x == s_ak.end()) throw(std::runtime_error("Encountered pull request for unknown tensor."));
124129
// Wait for previous answer to complete so that we can re-use the buffer
125130
MPI_Wait(&request_out, MPI_STATUS_IGNORE);
126-
x->second->bufferize(slice, rbuff);
127-
if(slice.size() * x->second->item_size() != rbuff.size()) throw(std::runtime_error("Got unexpected buffer size."));
131+
auto ptr = x->second.lock();
132+
ptr->bufferize(slice, rbuff);
133+
if(slice.size() * ptr->item_size() != rbuff.size()) throw(std::runtime_error("Got unexpected buffer size."));
128134
MPI_Isend(rbuff.data(), rbuff.size(), MPI_CHAR, requester, PUSH_TAG, comm, &request_out);
129135
} while(true);
130136
// MPI_Cancel(&request_in);

src/include/ddptensor/MPIMediator.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class MPIMediator : public Mediator
1313
MPIMediator();
1414
virtual ~MPIMediator();
1515
virtual uint64_t register_array(tensor_i::ptr_type ary);
16+
virtual uint64_t unregister_array(uint64_t);
1617
virtual void pull(rank_type from, const tensor_i & ary, const NDSlice & slice, void * buffer);
1718

1819
protected:

src/include/ddptensor/Mediator.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ class NDSlice;
1010
class Mediator
1111
{
1212
public:
13+
enum : uint64_t {LOCAL_ONLY = 0};
1314
virtual ~Mediator() {}
1415
virtual uint64_t register_array(tensor_i::ptr_type ary) = 0;
16+
virtual uint64_t unregister_array(uint64_t) = 0;
1517
virtual void pull(rank_type from, const tensor_i & ary, const NDSlice & slice, void * buffer) = 0;
1618
};
1719

src/include/ddptensor/x.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ namespace x
4242
template<typename T>
4343
class DPTensorX : public DPTensorBaseX
4444
{
45-
uint64_t _id = 0;
45+
uint64_t _id = Mediator::LOCAL_ONLY;
4646
mutable rank_type _owner;
4747
PVSlice _slice;
4848
xt::xstrided_slice_vector _lslice;
@@ -117,6 +117,11 @@ namespace x
117117
_xarray = org;
118118
}
119119

120+
~DPTensorX()
121+
{
122+
if(_id != Mediator::LOCAL_ONLY && theMediator) theMediator->unregister_array(_id);
123+
}
124+
120125
bool is_sliced() const
121126
{
122127
return _slice.is_sliced();

test/test_linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@
1818
b = np.reshape(b, (6,5))
1919
c = np.dot(a, b)
2020
print(c)
21+
22+
dt.fini()

test/test_x.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ddptensor as dt
22
a = dt.ones([4,4], dt.float64)
3+
a = dt.ones([4,4], dt.float64)
34
b = dt.ones([4,4], dt.float64)
45
a += b
56
print(a)
@@ -9,4 +10,5 @@
910
print(a[0:1,0:1], float(a[0:1,0:1]), bool(a[0:1,0:1]), int(a[0:1,0:1]))
1011
print(a[0:2,0:2])
1112
print(float(a[1:2, 1:2]))
13+
1214
dt.fini()

0 commit comments

Comments
 (0)