|
15 | 15 |
|
16 | 16 | using OutputAdapter = bitsery::OutputBufferAdapter<Buffer>; |
17 | 17 | using InputAdapter = bitsery::InputBufferAdapter<Buffer>; |
18 | | -using array_keeper_type = std::unordered_map<uint64_t, tensor_i::ptr_type>; |
| 18 | +using array_keeper_type = std::unordered_map<uint64_t, tensor_i::ptr_type::weak_type>; |
19 | 19 | using locker = std::lock_guard<std::mutex>; |
20 | 20 |
|
21 | 21 | static array_keeper_type s_ak; |
22 | | -static uint64_t s_last_id = 0; |
| 22 | +static uint64_t s_last_id = Mediator::LOCAL_ONLY; |
23 | 23 | constexpr static int PULL_TAG = 4711; |
24 | 24 | constexpr static int PUSH_TAG = 4712; |
25 | 25 | static std::mutex ak_mutex; |
@@ -53,6 +53,11 @@ uint64_t MPIMediator::register_array(tensor_i::ptr_type ary) |
53 | 53 | return s_last_id; |
54 | 54 | } |
55 | 55 |
|
| 56 | +uint64_t MPIMediator::unregister_array(uint64_t id) |
| 57 | +{ |
| 58 | + s_ak.erase(id); |
| 59 | +} |
| 60 | + |
56 | 61 | void MPIMediator::pull(rank_type from, const tensor_i & ary, const NDSlice & slice, void * rbuff) |
57 | 62 | { |
58 | 63 | MPI_Comm comm = MPI_COMM_WORLD; |
@@ -123,8 +128,9 @@ void MPIMediator::listen() |
123 | 128 | if(x == s_ak.end()) throw(std::runtime_error("Encountered pull request for unknown tensor.")); |
124 | 129 | // Wait for previous answer to complete so that we can re-use the buffer |
125 | 130 | MPI_Wait(&request_out, MPI_STATUS_IGNORE); |
126 | | - x->second->bufferize(slice, rbuff); |
127 | | - if(slice.size() * x->second->item_size() != rbuff.size()) throw(std::runtime_error("Got unexpected buffer size.")); |
| 131 | + auto ptr = x->second.lock(); |
| 132 | + ptr->bufferize(slice, rbuff); |
| 133 | + if(slice.size() * ptr->item_size() != rbuff.size()) throw(std::runtime_error("Got unexpected buffer size.")); |
128 | 134 | MPI_Isend(rbuff.data(), rbuff.size(), MPI_CHAR, requester, PUSH_TAG, comm, &request_out); |
129 | 135 | } while(true); |
130 | 136 | // MPI_Cancel(&request_in); |
|
0 commit comments