Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit eaa83b5

Browse files
committed
introducing idtr, separating of python-dependent features
1 parent e684d81 commit eaa83b5

36 files changed

+577
-371
lines changed

CMakeLists.txt

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,45 @@ add_custom_command(
8282
# ============
8383
# Target
8484
# ============
85-
FILE(GLOB MyCppSources ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/include/ddptensor/*.hpp)
86-
set(MyCppSources ${MyCppSources} ${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp ${P2C_HPP})
8785

88-
pybind11_add_module(_ddptensor MODULE ${MyCppSources})
86+
FILE(GLOB Hpps ${PROJECT_SOURCE_DIR}/src/include/ddptensor/*.hpp)
87+
set(Hpps ${Hpps} ${P2C_HPP})
88+
89+
set(DDPTSrcs
90+
${PROJECT_SOURCE_DIR}/src/ddptensor.cpp
91+
${PROJECT_SOURCE_DIR}/src/Creator.cpp
92+
${PROJECT_SOURCE_DIR}/src/EWBinOp.cpp
93+
${PROJECT_SOURCE_DIR}/src/EWUnyOp.cpp
94+
${PROJECT_SOURCE_DIR}/src/IEWBinOp.cpp
95+
${PROJECT_SOURCE_DIR}/src/IO.cpp
96+
${PROJECT_SOURCE_DIR}/src/LinAlgOp.cpp
97+
${PROJECT_SOURCE_DIR}/src/ManipOp.cpp
98+
${PROJECT_SOURCE_DIR}/src/Random.cpp
99+
${PROJECT_SOURCE_DIR}/src/ReduceOp.cpp
100+
${PROJECT_SOURCE_DIR}/src/Service.cpp
101+
${PROJECT_SOURCE_DIR}/src/SetGetItem.cpp
102+
)
103+
set(IDTRSrcs
104+
${PROJECT_SOURCE_DIR}/src/idtr.cpp
105+
${PROJECT_SOURCE_DIR}/src/CollComm.cpp
106+
${PROJECT_SOURCE_DIR}/src/Deferred.cpp
107+
${PROJECT_SOURCE_DIR}/src/Factory.cpp
108+
${PROJECT_SOURCE_DIR}/src/Mediator.cpp
109+
${PROJECT_SOURCE_DIR}/src/MPIMediator.cpp
110+
${PROJECT_SOURCE_DIR}/src/MPITransceiver.cpp
111+
${PROJECT_SOURCE_DIR}/src/PVSlice.cpp
112+
${PROJECT_SOURCE_DIR}/src/Registry.cpp
113+
${PROJECT_SOURCE_DIR}/src/Transceiver.cpp
114+
${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp
115+
)
116+
117+
pybind11_add_module(_ddptensor MODULE ${DDPTSrcs} ${Hpps})
118+
add_library(idtr SHARED ${IDTRSrcs} ${Hpps})
119+
set(AllTargets _ddptensor idtr)
89120

90-
target_compile_definitions(_ddptensor PRIVATE USE_MKL=1 DDPT_2TYPES=1)
91-
target_compile_options(_ddptensor PRIVATE "-ftemplate-backtrace-limit=0")
92-
target_include_directories(_ddptensor PRIVATE
121+
add_compile_definitions(USE_MKL=1)
122+
add_compile_options("-ftemplate-backtrace-limit=0")
123+
include_directories(
93124
${PROJECT_SOURCE_DIR}/src/include
94125
${PROJECT_SOURCE_DIR}/third_party/bitsery/include
95126
${MPI_INCLUDE_PATH}
@@ -103,20 +134,22 @@ if (CMAKE_SYSTEM_NAME STREQUAL Linux)
103134
target_link_options(_ddptensor PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/export.txt")
104135
endif()
105136

106-
#target_compile_options(_ddptensor PRIVATE -fopenmp)
137+
#compile_options(_ddptensor PRIVATE -fopenmp)
107138
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
108139
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
109140
get_property(mlir_all_libs GLOBAL PROPERTY MLIR_ALL_LIBS)
110141
get_property(imex_all_libs GLOBAL PROPERTY IMEX_ALL_LIBS)
111142

112143
#llvm_update_compile_flags(_ddpttensor)
113-
target_link_directories(_ddptensor PRIVATE
114-
${CONDA_PREFIX}/lib
115-
${IMEX_INSTALL_PREFIX}/lib
116-
)
144+
target_link_directories(_ddptensor PRIVATE ${CONDA_PREFIX}/lib)
145+
target_link_directories(idtr PRIVATE ${CONDA_PREFIX}/lib ${IMEX_INSTALL_PREFIX}/lib)
117146

118-
message(${imex_all_libs})
119147
target_link_libraries(_ddptensor PRIVATE
148+
# ${MKL_LIBRARIES}
149+
# tbb
150+
idtr
151+
)
152+
target_link_libraries(idtr PRIVATE
120153
${MPI_C_LIBRARIES}
121154
# ${MKL_LIBRARIES}
122155
tbb

export.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
global: PyInit__ddptensor;
3+
local: *;
4+
};

scripts/code_gen.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
// #######################################################
1313
// SPDX-License-Identifier: BSD-3-Clause
1414
#pragma once
15+
#ifdef DEF_PY11_ENUMS
1516
#include <pybind11/pybind11.h>
1617
#include <pybind11/stl.h>
1718
namespace py = pybind11;
19+
#endif
1820
""")
1921

2022
prev = 0
@@ -27,14 +29,16 @@
2729
print(f" {prev}")
2830
print("};\n")
2931

30-
print("static void def_enums(py::module_ & m)\n{")
32+
print("""#ifdef DEF_PY11_ENUMS
33+
static void def_enums(py::module_ & m)
34+
{""")
3135
for cat, lst in api.api_categories.items():
3236
print(f' py::enum_<{cat}Id>(m, "{cat}Id")')
3337
for x in lst:
3438
print(f' .value("{x.upper()}", {x.upper()})')
3539
print(" .export_values();\n")
3640

37-
print("}")
41+
print("}\n#endif\n")
3842

3943
# Close the file
4044
sys.stdout.close()

src/CollComm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// This is not implemented: we need an extra mechanism to work with reshape-views or alike.
1616
std::vector<std::vector<int>> CollComm::map(const PVSlice & n_slc, const PVSlice & o_slc)
1717
{
18-
auto nr = theTransceiver->nranks();
18+
auto nr = getTransceiver()->nranks();
1919
std::vector<int> counts_send(nr, 0);
2020
std::vector<int> disp_send(nr, 0);
2121
std::vector<int> counts_recv(nr, 0);
@@ -26,14 +26,14 @@ std::vector<std::vector<int>> CollComm::map(const PVSlice & n_slc, const PVSlice
2626
// tilesize of my local partition of orig array
2727
auto o_tsz = o_slc.tile_size();
2828
// linearized local slice of orig array
29-
auto o_llslc = Slice(o_ntsz * theTransceiver->rank(), o_ntsz * theTransceiver->rank() + o_tsz);
29+
auto o_llslc = Slice(o_ntsz * getTransceiver()->rank(), o_ntsz * getTransceiver()->rank() + o_tsz);
3030

3131
// norm tile-size of new (reshaped) array
3232
auto n_ntsz = n_slc.tile_size(0);
3333
// tilesize of my local partition of new (reshaped) array
3434
auto n_tsz = n_slc.tile_size();
3535
// linearized/flattened/1d local slice of new (reshaped) array
36-
auto n_llslc = Slice(n_ntsz * theTransceiver->rank(), n_ntsz * theTransceiver->rank() + n_tsz);
36+
auto n_llslc = Slice(n_ntsz * getTransceiver()->rank(), n_ntsz * getTransceiver()->rank() + n_tsz);
3737

3838
for(auto r=0; r<nr; ++r) {
3939
// determine what I receive from rank r

src/Deferred.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ void push_runable(Runable::ptr_type && r)
2222

2323
void _dist(const Runable * p)
2424
{
25-
if(is_cw() && theTransceiver->rank() == 0)
26-
theMediator->to_workers(p);
25+
if(getTransceiver()->is_cw() && getTransceiver()->rank() == 0)
26+
getMediator()->to_workers(p);
2727
}
2828

2929
Deferred::future_type Deferred::get_future()

src/EWBinOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ namespace x {
166166

167167
// Step 1: Get the mapping of a and b to our resulting slice
168168

169-
rank_type rank = theTransceiver->rank();
169+
rank_type rank = getTransceiver()->rank();
170170
// Slice for/of the result
171171
PVSlice r_slc(a_sptr->slice().size() >= b_sptr->slice().size() ? a_sptr->slice().shape() : b_sptr->slice().shape());
172172
// Size of local result tile

src/IO.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct DeferredToNumpy : public DeferredT<promise_type, future_type>
1919
void run()
2020
{
2121
const auto a = std::move(Registry::get(_a).get());
22-
set_value(GetItem::do_gather(a, is_cw() ? 0 : REPLICATED));
22+
set_value(GetItem::do_gather(a, getTransceiver()->is_cw() ? 0 : REPLICATED));
2323
}
2424

2525
FactoryId factory() const
@@ -36,7 +36,7 @@ struct DeferredToNumpy : public DeferredT<promise_type, future_type>
3636

3737
py::object IO::to_numpy(const ddptensor & a)
3838
{
39-
assert(!is_cw() || theTransceiver->rank() == 0);
39+
assert(!getTransceiver()->is_cw() || getTransceiver()->rank() == 0);
4040
auto f = defer<DeferredToNumpy>(a.get());
4141
auto x = f.get();
4242
return x;

src/LinAlgOp.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ namespace x {
4848
static ptr_type vecdot_1d(const T1 & a, const T2 & b, int axis)
4949
{
5050
auto d = xt::linalg::dot(a, b)();
51-
theTransceiver->reduce_all(&d, DTYPE<decltype(d)>::value, 1, SUM);
51+
getTransceiver()->reduce_all(&d, DTYPE<decltype(d)>::value, 1, SUM);
5252
return operatorx<decltype(d)>::mk_tx(d, REPLICATED);
5353
}
5454

@@ -58,8 +58,8 @@ namespace x {
5858
if(a_ptr->slice().split_dim() != 0)
5959
throw(std::runtime_error("vecdoc_2d supported for split_dim=0 only"));
6060

61-
auto nr = theTransceiver->nranks();
62-
auto me = theTransceiver->rank();
61+
auto nr = getTransceiver()->nranks();
62+
auto me = getTransceiver()->rank();
6363
rank_type right = (me + 1) % nr;
6464
rank_type left = (nr + me - 1) % nr;
6565
auto tsz = b_ptr->slice().tile_size(0);
@@ -97,7 +97,7 @@ namespace x {
9797
if(i > 1) {
9898
// data exchange
9999
// FIXME: optimize data transfer: last partition might contain unused data
100-
theTransceiver->send_recv(buff.data(),
100+
getTransceiver()->send_recv(buff.data(),
101101
tsz,
102102
DTYPE<A>::value,
103103
left,

src/MPIMediator.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <unordered_map>
77
#include <mutex>
88

9-
#include "ddptensor/UtilsAndTypes.hpp"
9+
#include "ddptensor/CppTypes.hpp"
1010
#include "ddptensor/MPIMediator.hpp"
1111
#include "ddptensor/MPITransceiver.hpp"
1212
#include "ddptensor/NDSlice.hpp"
@@ -24,7 +24,7 @@ void send_to_workers(const Runable * dfrd, bool self, MPI_Comm comm);
2424
MPIMediator::MPIMediator()
2525
: _listener(nullptr)
2626
{
27-
auto c = dynamic_cast<MPITransceiver*>(theTransceiver);
27+
auto c = dynamic_cast<MPITransceiver*>(getTransceiver());
2828
if(c == nullptr) throw std::runtime_error("Expected Transceiver to be MPITransceiver.");
2929
_comm = c->comm();
3030
int sz;
@@ -40,9 +40,9 @@ MPIMediator::~MPIMediator()
4040
MPI_Comm_rank(_comm, &rank);
4141
MPI_Comm_size(_comm, &sz);
4242

43-
if(is_cw() && rank == 0) to_workers(nullptr);
43+
if(getTransceiver()->is_cw() && rank == 0) to_workers(nullptr);
4444
MPI_Barrier(_comm);
45-
if(!is_cw() || rank == 0) send_to_workers(nullptr, true, _comm);
45+
if(!getTransceiver()->is_cw() || rank == 0) send_to_workers(nullptr, true, _comm);
4646
if(_listener) {
4747
_listener->join();
4848
delete _listener;

src/MPITransceiver.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
#include <mpi.h>
44
#include <limits>
55
#include <sstream>
6+
#include <iostream>
67
#include "ddptensor/MPITransceiver.hpp"
78

8-
MPITransceiver::MPITransceiver()
9-
: _nranks(1), _rank(0), _comm(MPI_COMM_WORLD)
9+
MPITransceiver::MPITransceiver(bool is_cw)
10+
: _nranks(1), _rank(0), _comm(MPI_COMM_WORLD), _is_cw(is_cw)
1011
{
1112
int flag;
1213
MPI_Initialized(&flag);

0 commit comments

Comments
 (0)