Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit c63171a

Browse files
committed
fixing shutdown
1 parent c427db6 commit c63171a

File tree

12 files changed

+60
-11
lines changed

12 files changed

+60
-11
lines changed

ddptensor/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# At this point there are no checks of input arguments whatsoever, arguments
1414
# are simply forwarded as-is.
1515

16+
_bool = bool
1617
from . import _ddptensor as _cdt
1718
from ._ddptensor import (
1819
FLOAT64 as float64,
@@ -26,15 +27,22 @@
2627
UINT16 as uint16,
2728
UINT8 as uint8,
2829
BOOL as bool,
29-
init,
30+
init as _init,
3031
fini,
3132
sync
3233
)
34+
3335
from .ddptensor import dtensor
3436
from os import getenv
3537
from . import array_api as api
3638
from . import spmd
3739

40+
# Module-level default for init()'s `cw` argument, taken from the DDPT_CW
# environment variable; when the variable is unset the default is True.
# The value is passed through int() first, so it must be an integer string
# such as "0" or "1". `_bool` is the builtin bool, saved under an alias
# because the name `bool` is re-bound to the BOOL dtype by the import above.
_ddpt_cw = _bool(int(getenv('DDPT_CW', True)))
41+
42+
def init(cw=None):
    """Initialize the native runtime by forwarding to the extension's init.

    Args:
        cw: Flag handed to the native ``init``. When ``None`` (the default),
            the module-level ``_ddpt_cw`` value derived from the ``DDPT_CW``
            environment variable is used instead.
            NOTE(review): presumably selects controller/worker mode -- confirm
            against the native implementation.
    """
    if cw is None:
        # Caller gave no explicit choice; fall back to the environment default.
        cw = _ddpt_cw
    _init(cw)
45+
3846
for op in api.api_categories["EWBinOp"]:
3947
if not op.startswith("__"):
4048
OP = op.upper()

src/Deferred.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ Deferred::ptr_type Deferred::undefer_next()
3737
return r;
3838
}
3939

40+
void Deferred::fini()
41+
{
42+
_deferred.clear();
43+
}
44+
4045
void process_promises()
4146
{
4247
while(true) {
@@ -55,3 +60,4 @@ void sync()
5560
std::this_thread::sleep_for(std::chrono::milliseconds(1));
5661
}
5762
}
63+

src/MPIMediator.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@ static std::mutex ak_mutex;
2121
void send_to_workers(const Deferred::ptr_type & dfrd, bool self = false);
2222

2323
MPIMediator::MPIMediator()
24-
: _listener(&MPIMediator::listen, this)
24+
: _listener(nullptr)
2525
{
26+
MPI_Comm comm = MPI_COMM_WORLD;
27+
int sz;
28+
MPI_Comm_size(comm, &sz);
29+
if(sz > 1)
30+
_listener = new std::thread(&MPIMediator::listen, this);
2631
}
2732

2833
MPIMediator::~MPIMediator()
@@ -36,7 +41,11 @@ MPIMediator::~MPIMediator()
3641
if(is_cw() && rank == 0) to_workers(nullptr);
3742
MPI_Barrier(comm);
3843
if(!is_cw() || rank == 0) send_to_workers(nullptr, true);
39-
_listener.join();
44+
if(_listener) {
45+
_listener->join();
46+
delete _listener;
47+
_listener = nullptr;
48+
}
4049
}
4150

4251
void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void * rbuff)

src/MPITransceiver.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ MPITransceiver::MPITransceiver()
1515
throw std::runtime_error("Your MPI implementation is not MPI_THREAD_MULTIPLE. "
1616
"Please use a thread-safe MPI implementation.");
1717
}
18+
} else {
19+
std::cerr << "MPI already initialized\n";
1820
}
1921
int nranks, rank;
2022
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
@@ -23,6 +25,14 @@ MPITransceiver::MPITransceiver()
2325
_rank = rank;
2426
};
2527

28+
// Finalizes MPI when the transceiver is destroyed, unless MPI reports that it
// has already been finalized.
// NOTE(review): this finalizes MPI even when the constructor found it already
// initialized by someone else; confirm that taking over shutdown in that case
// is intended.
MPITransceiver::~MPITransceiver()
{
    int already_finalized = 0;
    MPI_Finalized(&already_finalized);
    if(already_finalized) return;  // nothing left to shut down
    MPI_Finalize();
}
35+
2636
static MPI_Datatype to_mpi(DTypeId T)
2737
{
2838
switch(T) {

src/Registry.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,10 @@ namespace Registry {
3737
locker _l(_mutex);
3838
_keeper.erase(id);
3939
}
40+
41+
void fini()
42+
{
43+
locker _l(_mutex);
44+
_keeper.clear();
45+
}
4046
}

src/ddptensor.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,13 @@ bool is_cw()
5353
return _is_cw && theTransceiver->nranks() > 1;
5454
}
5555

56+
static bool inited = false;
57+
static bool finied = false;
58+
5659
// users currently need to call fini to make MPI terminate gracefully
5760
void fini()
5861
{
62+
if(finied) return;
5963
delete theMediator; // stop task is sent in here
6064
theMediator = nullptr;
6165
if(pprocessor) {
@@ -65,10 +69,15 @@ void fini()
6569
}
6670
delete theTransceiver;
6771
theTransceiver = nullptr;
72+
Deferred::fini();
73+
Registry::fini();
74+
inited = false;
75+
finied = true;
6876
}
6977

7078
void init(bool cw)
7179
{
80+
if(inited) return;
7281
if(cw) {
7382
_is_cw = true;
7483
if(theTransceiver->rank()) {
@@ -78,6 +87,8 @@ void init(bool cw)
7887
}
7988
}
8089
pprocessor = new std::thread(process_promises);
90+
inited = true;
91+
finied = false;
8192
}
8293

8394
// #########################################################################

src/include/ddptensor/Deferred.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct Deferred : tensor_i::promise_type
2424

2525
static future_type defer(ptr_type &&, bool);
2626
static ptr_type undefer_next();
27+
static void fini();
2728
};
2829

2930
template<typename T, typename... Ts>

src/include/ddptensor/MPIMediator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class MPIMediator : public Mediator
99
{
10-
std::thread _listener;
10+
std::thread * _listener;
1111

1212
public:
1313
MPIMediator();

src/include/ddptensor/MPITransceiver.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class MPITransceiver : public Transceiver
88
{
99
public:
1010
MPITransceiver();
11+
~MPITransceiver();
1112

1213
rank_type nranks() const
1314
{

src/include/ddptensor/Registry.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ namespace Registry {
1717
extern void put(id_type id, tensor_i::ptr_type ptr);
1818
tensor_i::ptr_type get(id_type id);
1919
void del(id_type id);
20+
void fini();
2021
};

0 commit comments

Comments
 (0)