Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit e48a709

Browse files
committed
fixing possible races
1 parent a6a8d78 commit e48a709

File tree

5 files changed

+27
-5
lines changed

5 files changed

+27
-5
lines changed

src/MPIMediator.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <thread>
55
#include <iostream>
66
#include <unordered_map>
7+
#include <mutex>
78
#include <bitsery/bitsery.h>
89
#include <bitsery/adapter/buffer.h>
910
#include <bitsery/traits/vector.h>
@@ -15,12 +16,13 @@
1516
// bitsery adapters writing to / reading from a Buffer
using OutputAdapter = bitsery::OutputBufferAdapter<Buffer>;
1617
using InputAdapter = bitsery::InputBufferAdapter<Buffer>;
1718
// Maps a 64-bit id to the tensor registered under that id
using array_keeper_type = std::unordered_map<uint64_t, tensor_i::ptr_type>;
19+
// Scoped lock on a std::mutex; released when the guard goes out of scope
using locker = std::lock_guard<std::mutex>;
1820

1921
// Registry of tensors, filled by register_array() and searched by listen()
static array_keeper_type s_ak;
2022
// Most recently handed-out id; register_array() pre-increments it, so the first id is 1
static uint64_t s_last_id = 0;
2123
// MPI message tags; PULL_TAG is the tag used for pull requests (sent in the destructor, received in listen())
constexpr static int PULL_TAG = 4711;
2224
constexpr static int PUSH_TAG = 4712;
23-
25+
// Taken around every access to s_ak (and the s_last_id update) visible in this file
static std::mutex ak_mutex;
2426

2527
MPIMediator::MPIMediator()
2628
: _listener(&MPIMediator::listen, this)
@@ -40,11 +42,13 @@ MPIMediator::~MPIMediator()
4042
ser.adapter().flush();
4143
MPI_Send(buff.data(), buff.size(), MPI_CHAR, rank, PULL_TAG, MPI_COMM_WORLD);
4244
_listener.join();
45+
locker _l(ak_mutex);
4346
s_ak.clear();
4447
}
4548

4649
// Registers `ary` in the file-static map s_ak under a newly incremented id.
// @param ary  tensor pointer to store
// @return     the id under which `ary` was stored (the new value of s_last_id)
uint64_t MPIMediator::register_array(tensor_i::ptr_type ary)
4750
{
51+
// Hold ak_mutex for the increment, the insertion and the read of s_last_id:
// listen() looks up s_ak under the same mutex.
locker _l(ak_mutex);
4852
s_ak[++s_last_id] = ary;
4953
return s_last_id;
5054
}
@@ -114,6 +118,7 @@ void MPIMediator::listen()
114118
MPI_Irecv(buff.data(), buff.size(), MPI_CHAR, MPI_ANY_SOURCE, PULL_TAG, comm, &request_in);
115119

116120
// Now find the array in question and send back its bufferized slice
121+
locker _l(ak_mutex);
117122
auto x = s_ak.find(id);
118123
if(x == s_ak.end()) throw(std::runtime_error("Encountered pull request for unknown tensor."));
119124
// Wait for previous answer to complete so that we can re-use the buffer

src/MPITransceiver.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ static MPI_Op to_mpi(RedOpType o)
5656
}
5757
}
5858

59+
60+
61+
// Barrier implementation for the MPI transceiver: calls MPI_Barrier on
// MPI_COMM_WORLD. The return code of MPI_Barrier is not checked.
void MPITransceiver::barrier()
62+
{
63+
MPI_Barrier(MPI_COMM_WORLD);
64+
}
65+
5966
void MPITransceiver::bcast(void * ptr, size_t N, rank_type root)
6067
{
6168
MPI_Bcast(ptr, N, MPI_CHAR, root, MPI_COMM_WORLD);

src/include/ddptensor/MPITransceiver.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class MPITransceiver : public Transceiver
1919
return _rank;
2020
}
2121

22+
23+
// Overrides the pure virtual Transceiver::barrier(); defined in MPITransceiver.cpp via MPI_Barrier
virtual void barrier();
2224
virtual void bcast(void * ptr, size_t N, rank_type root);
2325
virtual void reduce_all(void * inout, DType T, size_t N, RedOpType op);
2426

src/include/ddptensor/Transceiver.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ class Transceiver
1212
virtual rank_type nranks() const = 0;
1313
virtual rank_type rank() const = 0;
1414

15+
// Barrier: synchronization point to be provided by each implementation
// (MPITransceiver implements it as MPI_Barrier on MPI_COMM_WORLD)
16+
virtual void barrier() = 0;
17+
1518
// Broadcast data from root to all other processes
1619
// @param[inout] ptr on root: pointer to data to be sent
1720
// on all other processes: pointer to buffer to store received data

src/include/ddptensor/ddptensor_impl.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class dtensor_impl : public tensor_i
292292
auto ptr = buff.ptr;
293293
auto pylen = VPROD(buff.shape);
294294
assert(buff.itemsize == sizeof(T));
295-
theTransceiver->reduce_all(ptr, DTYPE<T>::value, pylen, red_op(op));
295+
theTransceiver->reduce_all(ptr, dtype(), pylen, red_op(op));
296296
return create_dtensor(pvslice(), new_shape, ary, REPLICATED);
297297
}
298298

@@ -404,6 +404,7 @@ class dtensor_impl : public tensor_i
404404
NDSlice my_norm_slice = g_slc_view.map_slice(my_slice);
405405
std::cerr << "my_norm_slice: " << my_norm_slice << std::endl;
406406

407+
theTransceiver->barrier();
407408
_set_slice(cast(val), my_norm_slice, this, my_slice);
408409
}
409410

@@ -431,9 +432,13 @@ class dtensor_impl : public tensor_i
431432
// Builds a temporary tensor of shape slice.shape() with partitioning NOSPLIT,
// hands it to _set_slice together with this tensor and `slice`, and returns
// the temporary's _pyarray.
// NOTE(review): presumably _set_slice copies the requested slice of this
// tensor into the temporary - confirm against _set_slice.
// @param slice  the slice to extract
// @return       the _pyarray of the temporary tensor
py::object get_slice(const NDSlice & slice) const
432433
{
433434
auto shp = slice.shape();
434-
auto out = create_dtensor(PVSlice(shp, NOSPLIT), shp, DTYPE<T>::value, "empty");
435-
_set_slice(this, slice, cast(out), {shp});
436-
return cast(out)->_pyarray;
435+
// Create dtensor without creating id: do not use create_dtensor
436+
py::dict kwa;
437+
kwa["dtype"] = get_impl_dtype<T>();
438+
// Allocate the backing array through the array namespace's "empty" function
auto ary = _array_ns.attr("empty")(_make_tuple(shp), kwa);
439+
// `out` is a local object; only its _pyarray is returned to the caller
auto out = dtensor_impl<T>(PVSlice(shp, NOSPLIT), shp, ary, theTransceiver->rank());
440+
_set_slice(this, slice, &out, {shp});
441+
return out._pyarray;
437442
}
438443

439444
std::string __repr__() const

0 commit comments

Comments
 (0)