Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit aadf34e

Browse files
committed
adding random.uniform and seed, better use of cmake (MKL) and more types and modern dispatch
1 parent 37871c0 commit aadf34e

File tree

20 files changed

+187
-90
lines changed

20 files changed

+187
-90
lines changed

CMakeLists.txt

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
1616

1717
# Find Python3 and NumPy
1818
find_package(Python3 COMPONENTS Interpreter Development.Module NumPy REQUIRED)
19-
20-
find_package(Python COMPONENTS Interpreter Development)
2119
find_package(pybind11 CONFIG)
2220
find_package(MPI REQUIRED)
23-
include_directories(SYSTEM ${MPI_INCLUDE_PATH} ${pybind11_INCLUDE_DIRS})
2421

22+
set(MKL_LIBRARIES -L$ENV{MKLROOT}/lib -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread -lrt -ldl -lm)
2523
# Use -fPIC even if statically compiled
2624
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
2725

@@ -34,6 +32,6 @@ FILE(GLOB MyCppSources ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src
3432
#add_library(_ddptensor MODULE ${MyCppSources})
3533
pybind11_add_module(_ddptensor THIN_LTO ${MyCppSources})
3634

37-
target_include_directories(_ddptensor PRIVATE ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/third_party/xtl/include ${PROJECT_SOURCE_DIR}/third_party/xsimd/include ${PROJECT_SOURCE_DIR}/third_party/xtensor-blas/include ${PROJECT_SOURCE_DIR}/third_party/xtensor/include ${PROJECT_SOURCE_DIR}/third_party/bitsery/include)
38-
39-
target_link_libraries(_ddptensor PRIVATE ${MPI_C_LIBRARIES})
35+
target_compile_definitions(_ddptensor PRIVATE USE_MKL=1 XTENSOR_USE_XSIMD=1 XTENSOR_USE_OPENMP=1) # DDPT_2TYPES=1)
36+
target_include_directories(_ddptensor PRIVATE ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/third_party/xtl/include ${PROJECT_SOURCE_DIR}/third_party/xsimd/include ${PROJECT_SOURCE_DIR}/third_party/xtensor-blas/include ${PROJECT_SOURCE_DIR}/third_party/xtensor/include ${PROJECT_SOURCE_DIR}/third_party/bitsery/include ${MPI_INCLUDE_PATH} $ENV{MKLROOT}/include ${pybind11_INCLUDE_DIRS})
37+
target_link_libraries(_ddptensor PRIVATE ${MPI_C_LIBRARIES} ${MKL_LIBRARIES})

ddptensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
for op in api.ew_binary_ops:
2424
OP = op.upper()
2525
exec(
26-
f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t))" # if isinstance(other, ddptensor) else other, False))"
26+
f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t if isinstance(other, ddptensor) else other))"
2727
)
2828

2929
for op in api.ew_unary_ops:

ddptensor/array_api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44
https://data-apis.org/array-api/latest
55
"""
66

7+
dtypes = [
8+
"float32",
9+
"float64",
10+
"int8",
11+
"int16",
12+
"int32",
13+
"int64",
14+
"uint8",
15+
"uint16",
16+
"uint32",
17+
"uint64",
18+
"bool",
19+
]
20+
721
creators = [
822
"arange", # (start, /, stop=None, step=1, *, dtype=None, device=None)
923
"asarray", # (obj, /, *, dtype=None, device=None, copy=None)
@@ -62,6 +76,8 @@
6276
"tan", # (x, /)
6377
"tanh", # (x, /)
6478
"trunc", # (x, /)
79+
# non standard from here
80+
"erf", # (x, /)
6581
]
6682

6783
ew_binary_methods_inplace = [

ddptensor/ddptensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __repr__(self):
2020
for method in api.ew_binary_methods:
2121
METHOD = method.upper()
2222
exec(
23-
f"{method} = lambda self, other: dtensor(_cdt.EWBinOp.op(_cdt.{METHOD}, self._t, other._t))" # if isinstance(other, dtensor) else other, True))"
23+
f"{method} = lambda self, other: dtensor(_cdt.EWBinOp.op(_cdt.{METHOD}, self._t, other._t if isinstance(other, dtensor) else other))"
2424
)
2525

2626
for method in api.ew_binary_methods_inplace:

ddptensor/random.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from . import _ddptensor as _cdt
2+
from . ddptensor import dtensor
3+
4+
def uniform(low, high, size, dtype=_cdt.float64):
5+
return dtensor(_cdt.Random.uniform(dtype, size, low, high))
6+
7+
def seed(s):
8+
_cdt.Random.seed(s)

scripts/code_gen.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,16 @@
1313
namespace py = pybind11;
1414
""")
1515

16+
# dtypes must go first
17+
print("enum DType {")
18+
for x in api.dtypes:
19+
print(f" DT_{x.upper()},")
20+
print(" DTYPE_LAST")
21+
print("};\n")
22+
1623
print("enum CreatorId : int {")
1724
for x in api.creators:
25+
x = x + " = DTYPE_LAST" if x == api.creators[0] else x
1826
print(f" {x.upper()},")
1927
print(" CREATOR_LAST")
2028
print("};\n")
@@ -51,6 +59,11 @@
5159

5260
print("static void def_enums(py::module_ & m)\n{")
5361

62+
print(' py::enum_<DType>(m, "dtype")')
63+
for x in api.dtypes:
64+
print(f' .value("{x}", DT_{x.upper()})')
65+
print(" .export_values();\n")
66+
5467
print(' py::enum_<CreatorId>(m, "CreatorId")')
5568
for x in api.creators:
5669
print(f' .value("{x.upper()}", {x.upper()})')

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def build_cmake(self, ext):
5353
setup(name="ddptensor",
5454
version="0.1",
5555
description="Distributed Tensor and more",
56-
packages=["ddptensor", "ddptensor.numpy", "ddptensor.torch"],
56+
packages=["ddptensor"], #, "ddptensor.numpy", "ddptensor.torch"],
5757
ext_modules=[CMakeExtension('ddptensor/_ddptensor')],
5858
cmdclass=dict(
5959
# Enable the CMakeExtension entries defined above

src/EWBinOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,5 +115,5 @@ namespace x {
115115

116116
tensor_i::ptr_type EWBinOp::op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
117117
{
118-
return TypeDispatch2<x::EWBinOp>(a, b, op);
118+
return TypeDispatch<x::EWBinOp>(a, b, op);
119119
}

src/EWUnyOp.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ namespace x {
7474
return operatorx<T>::mk_tx_(a_ptr, xt::tanh(a));
7575
case TRUNC:
7676
return operatorx<T>::mk_tx_(a_ptr, xt::trunc(a));
77+
case ERF:
78+
return operatorx<T>::mk_tx_(a_ptr, xt::erf(a));
7779
case __NEG__:
7880
case NEGATIVE:
7981
case __POS__:
@@ -98,5 +100,5 @@ namespace x {
98100

99101
tensor_i::ptr_type EWUnyOp::op(EWUnyOpId op, x::DPTensorBaseX::ptr_type a)
100102
{
101-
return TypeDispatch2<x::EWUnyOp>(a, op);
103+
return TypeDispatch<x::EWUnyOp>(a, op);
102104
}

src/IEWBinOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,5 @@ namespace x {
6464

6565
void IEWBinOp::op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
6666
{
67-
TypeDispatch2<x::IEWBinOp>(a, b, op);
67+
TypeDispatch<x::IEWBinOp>(a, b, op);
6868
}

0 commit comments

Comments
 (0)