Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 077bffd

Browse files
committed
more constistent builds, names and uses of &
1 parent e5b97f0 commit 077bffd

File tree

17 files changed

+342
-684
lines changed

17 files changed

+342
-684
lines changed

CMakeLists.txt

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,32 @@ set(MKL_LIBRARIES -L$ENV{MKLROOT}/lib -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_
2424
# Use -fPIC even if statically compiled
2525
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
2626

27+
set(P2C_HPP ${PROJECT_SOURCE_DIR}/src/include/ddptensor/p2c_ids.hpp)
28+
# Generate enums
29+
add_custom_command(
30+
COMMAND python ${PROJECT_SOURCE_DIR}/scripts/code_gen.py ${PROJECT_SOURCE_DIR}/ddptensor/array_api.py ${P2C_HPP}
31+
DEPENDS ${PROJECT_SOURCE_DIR}/scripts/code_gen.py ${PROJECT_SOURCE_DIR}/ddptensor/array_api.py
32+
OUTPUT ${P2C_HPP}
33+
COMMENT "Generating ${P2C_HPP}."
34+
)
35+
2736
# ============
2837
# Target
2938
# ============
3039
FILE(GLOB MyCppSources ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/include/ddptensor/*.hpp)
40+
set(MyCppSources ${MyCppSources} ${P2C_HPP})
3141

32-
# Create the mymath library
33-
#add_library(_ddptensor MODULE ${MyCppSources})
3442
pybind11_add_module(_ddptensor MODULE ${MyCppSources})
3543

36-
target_compile_options(_ddptensor PRIVATE -fopenmp)
3744
target_compile_definitions(_ddptensor PRIVATE USE_MKL=1 XTENSOR_USE_XSIMD=1 XTENSOR_USE_OPENMP=1 DDPT_2TYPES=1)
38-
target_include_directories(_ddptensor PRIVATE ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/third_party/xtl/include ${PROJECT_SOURCE_DIR}/third_party/xsimd/include ${PROJECT_SOURCE_DIR}/third_party/xtensor-blas/include ${PROJECT_SOURCE_DIR}/third_party/xtensor/include ${PROJECT_SOURCE_DIR}/third_party/bitsery/include ${MPI_INCLUDE_PATH} $ENV{MKLROOT}/include ${pybind11_INCLUDE_DIRS})
45+
target_include_directories(_ddptensor PRIVATE
46+
${PROJECT_SOURCE_DIR}/src/include
47+
${PROJECT_SOURCE_DIR}/third_party/xtl/include
48+
${PROJECT_SOURCE_DIR}/third_party/xsimd/include
49+
${PROJECT_SOURCE_DIR}/third_party/xtensor-blas/include
50+
${PROJECT_SOURCE_DIR}/third_party/xtensor/include
51+
${PROJECT_SOURCE_DIR}/third_party/bitsery/include
52+
${MPI_INCLUDE_PATH} $ENV{MKLROOT}/include
53+
${pybind11_INCLUDE_DIRS})
54+
target_compile_options(_ddptensor PRIVATE -fopenmp)
3955
target_link_libraries(_ddptensor PRIVATE ${MPI_C_LIBRARIES} ${MKL_LIBRARIES})

ddptensor/__init__.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,39 @@
1414
# are simply forwarded as-is.
1515

1616
from . import _ddptensor as _cdt
17-
from ._ddptensor import float64, float32, int64, int32, int16, uint64, uint32, uint16, fini
17+
from ._ddptensor import (
18+
FLOAT64 as float64,
19+
FLOAT32 as float32,
20+
INT64 as int64,
21+
INT32 as int32,
22+
INT16 as int16,
23+
INT8 as int8,
24+
UINT64 as uint64,
25+
UINT32 as uint32,
26+
UINT16 as uint16,
27+
UINT8 as uint8,
28+
fini
29+
)
1830
from .ddptensor import dtensor
1931
from os import getenv
2032
from . import array_api as api
2133
from . import spmd
2234

23-
for op in api.ew_binary_ops:
24-
OP = op.upper()
25-
exec(
26-
f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t if isinstance(other, ddptensor) else other))"
27-
)
35+
for op in api.api_categories["EWBinOp"]:
36+
if not op.startswith("__"):
37+
OP = op.upper()
38+
exec(
39+
f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t if isinstance(other, ddptensor) else other))"
40+
)
2841

29-
for op in api.ew_unary_ops:
30-
OP = op.upper()
31-
exec(
32-
f"{op} = lambda this: dtensor(_cdt.EWUnyOp.op(_cdt.{OP}, this._t))"
33-
)
42+
for op in api.api_categories["EWUnyOp"]:
43+
if not op.startswith("__"):
44+
OP = op.upper()
45+
exec(
46+
f"{op} = lambda this: dtensor(_cdt.EWUnyOp.op(_cdt.{OP}, this._t))"
47+
)
3448

35-
for func in api.creators:
49+
for func in api.api_categories["Creator"]:
3650
FUNC = func.upper()
3751
if func in ["empty", "ones", "zeros",]:
3852
exec(
@@ -43,7 +57,7 @@
4357
f"{func} = lambda shape, val, dtype: dtensor(_cdt.Creator.full(_cdt.{FUNC}, shape, val, dtype))"
4458
)
4559

46-
for func in api.statisticals:
60+
for func in api.api_categories["ReduceOp"]:
4761
FUNC = func.upper()
4862
exec(
4963
f"{func} = lambda this, dim: dtensor(_cdt.ReduceOp.op(_cdt.{FUNC}, this._t, dim))"

0 commit comments

Comments
 (0)