Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 37871c0

Browse files
committed
allow binary ops with different dtypes, cmake fixes
1 parent 9e34a0b commit 37871c0

File tree

13 files changed

+283
-232
lines changed

13 files changed

+283
-232
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# emacs
132+
*~

CMakeLists.txt

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
1010
# Common installation directories
1111
#include(GNUInstallDirs)
1212

13+
# ===============
14+
# Deps
15+
# ===============
16+
17+
# Find Python3 and NumPy
18+
find_package(Python3 COMPONENTS Interpreter Development.Module NumPy REQUIRED)
19+
20+
find_package(Python COMPONENTS Interpreter Development)
21+
find_package(pybind11 CONFIG)
22+
find_package(MPI REQUIRED)
23+
include_directories(SYSTEM ${MPI_INCLUDE_PATH} ${pybind11_INCLUDE_DIRS})
24+
1325
# Use -fPIC even if statically compiled
1426
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
1527

@@ -19,18 +31,9 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
1931
FILE(GLOB MyCppSources ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/include/ddptensor/*.hpp)
2032

2133
# Create the _ddptensor extension module
22-
add_library(_ddptensor SHARED ${MyCppSources})
34+
#add_library(_ddptensor MODULE ${MyCppSources})
35+
pybind11_add_module(_ddptensor THIN_LTO ${MyCppSources})
2336

2437
target_include_directories(_ddptensor PRIVATE ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/third_party/xtl/include ${PROJECT_SOURCE_DIR}/third_party/xsimd/include ${PROJECT_SOURCE_DIR}/third_party/xtensor-blas/include ${PROJECT_SOURCE_DIR}/third_party/xtensor/include ${PROJECT_SOURCE_DIR}/third_party/bitsery/include)
2538

26-
# ===============
27-
# Deps
28-
# ===============
29-
30-
# Find Python3 and NumPy
31-
find_package(Python3 COMPONENTS Interpreter Development.Module NumPy REQUIRED)
32-
33-
find_package(MPI REQUIRED)
34-
find_package(pybind11 CONFIG)
35-
include_directories(SYSTEM ${MPI_INCLUDE_PATH} ${pybind11_INCLUDE_DIRS})
36-
target_link_libraries(_ddptensor ${MPI_C_LIBRARIES})
39+
target_link_libraries(_ddptensor PRIVATE ${MPI_C_LIBRARIES})

ddptensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
# are simply forwarded as-is.
1515

1616
from . import _ddptensor as _cdt
17-
from .ddptensor import float64, int64, fini, dtensor
17+
from ._ddptensor import float64, float32, int64, int32, int16, uint64, uint32, uint16, fini
18+
from .ddptensor import dtensor
1819
from os import getenv
1920
from . import array_api as api
2021
from . import spmd

setup.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,62 @@
1-
import cmake_build_extension
2-
from setuptools import setup
3-
from pathlib import Path
4-
5-
ext_modules = [
6-
cmake_build_extension.CMakeExtension(
7-
name="_ddptensor",
8-
# Name of the resulting package name (import mymath_pybind11)
9-
install_prefix="ddptensor",
10-
# Note: pybind11 is a build-system requirement specified in pyproject.toml,
11-
# therefore pypa/pip or pypa/build will install it in the virtual
12-
# environment created in /tmp during packaging.
13-
# This cmake_depends_on option adds the pybind11 installation path
14-
# to CMAKE_PREFIX_PATH so that the example finds the pybind11 targets
15-
# even if it is not installed in the system.
16-
cmake_depends_on=["pybind11"],
17-
# Exposes the binary print_answer to the environment.
18-
# It requires also adding a new entry point in setup.cfg.
19-
# expose_binaries=["bin/print_answer"],
20-
# Writes the content to the top-level __init__.py
21-
#write_top_level_init=init_py,
22-
# Selects the folder where the main CMakeLists.txt is stored
23-
# (it could be a subfolder)
24-
source_dir=str(Path(__file__).parent.absolute()),
25-
cmake_configure_options=[
26-
]
27-
),
28-
]
1+
import os
2+
import pathlib
3+
from setuptools import setup, Extension
4+
from setuptools.command.build_ext import build_ext as build_ext_orig
5+
6+
7+
class CMakeExtension(Extension):
8+
9+
def __init__(self, name):
10+
# don't invoke the original build_ext for this special extension
11+
super().__init__(name, sources=[])
12+
13+
14+
class build_ext(build_ext_orig):
15+
16+
def run(self):
17+
for ext in self.extensions:
18+
self.build_cmake(ext)
19+
super().run()
20+
21+
def build_cmake(self, ext):
22+
cwd = pathlib.Path().absolute()
23+
24+
# these dirs will be created in build_py, so if you don't have
25+
# any python sources to bundle, the dirs will be missing
26+
build_temp = pathlib.Path(self.build_temp)
27+
build_temp.mkdir(parents=True, exist_ok=True)
28+
extdir = pathlib.Path(self.get_ext_fullpath(ext.name))
29+
extdir.parent.mkdir(parents=True, exist_ok=True)
30+
31+
# example of cmake args
32+
config = 'Debug' if self.debug else 'Release'
33+
cmake_args = [
34+
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(extdir.parent.absolute()),
35+
'-DCMAKE_BUILD_TYPE=' + config
36+
]
37+
38+
# example of build args
39+
build_args = [
40+
'--config', config,
41+
'--', '-j8'
42+
]
43+
44+
os.chdir(str(build_temp))
45+
self.spawn(['cmake', str(cwd)] + cmake_args)
46+
if not self.dry_run:
47+
self.spawn(['cmake', '--build', '.'] + build_args)
48+
# Troubleshooting: if the build step above fails, delete all
49+
# temporary CMake files, including "CMakeCache.txt", in the top-level dir.
50+
os.chdir(str(cwd))
51+
2952

3053
setup(name="ddptensor",
3154
version="0.1",
3255
description="Distributed Tensor and more",
3356
packages=["ddptensor", "ddptensor.numpy", "ddptensor.torch"],
34-
ext_modules=ext_modules,
57+
ext_modules=[CMakeExtension('ddptensor/_ddptensor')],
3558
cmdclass=dict(
3659
# Enable the CMakeExtension entries defined above
37-
build_ext=cmake_build_extension.BuildExtension,
60+
build_ext=build_ext #cmake_build_extension.BuildExtension,
3861
),
3962
)

src/Creator.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace x {
88
{
99
public:
1010
using ptr_type = DPTensorBaseX::ptr_type;
11+
using typed_ptr_type = typename DPTensorX<T>::typed_ptr_type;
1112

1213
static ptr_type op(CreatorId c, shape_type && shp)
1314
{

src/EWBinOp.cpp

Lines changed: 57 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,114 +3,64 @@
33

44
namespace x {
55

6-
template<typename T>
76
class EWBinOp
87
{
98
public:
109
using ptr_type = DPTensorBaseX::ptr_type;
1110

1211
#pragma GCC diagnostic ignored "-Wswitch"
13-
14-
template<typename A, typename B, typename U = T, std::enable_if_t<std::is_floating_point<U>::value, bool> = true>
15-
static ptr_type integral_op(EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b)
16-
{
17-
throw std::runtime_error("Illegal or unknown inplace elementwise binary operation");
18-
}
19-
20-
template<typename A, typename B, typename U = T, std::enable_if_t<std::is_integral<U>::value, bool> = true>
21-
static ptr_type integral_op(EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b)
22-
{
23-
switch(iop) {
24-
case __AND__:
25-
case BITWISE_AND:
26-
return operatorx<T>::mk_tx_(tx, a & b);
27-
case __RAND__:
28-
return operatorx<T>::mk_tx_(tx, b & a);
29-
case __LSHIFT__:
30-
case BITWISE_LEFT_SHIFT:
31-
return operatorx<T>::mk_tx_(tx, a << b);
32-
case __MOD__:
33-
case REMAINDER:
34-
return operatorx<T>::mk_tx_(tx, a % b);
35-
case __OR__:
36-
case BITWISE_OR:
37-
return operatorx<T>::mk_tx_(tx, a | b);
38-
case __ROR__:
39-
return operatorx<T>::mk_tx_(tx, b | a);
40-
case __RSHIFT__:
41-
case BITWISE_RIGHT_SHIFT:
42-
return operatorx<T>::mk_tx_(tx, a >> b);
43-
case __XOR__:
44-
case BITWISE_XOR:
45-
return operatorx<T>::mk_tx_(tx, a ^ b);
46-
case __RXOR__:
47-
return operatorx<T>::mk_tx_(tx, b ^ a);
48-
case __RLSHIFT__:
49-
return operatorx<T>::mk_tx_(tx, b << a);
50-
case __RMOD__:
51-
return operatorx<T>::mk_tx_(tx, b % a);
52-
case __RRSHIFT__:
53-
return operatorx<T>::mk_tx_(tx, b >> a);
54-
default:
55-
throw std::runtime_error("Unknown elementwise binary operation");
56-
}
57-
}
58-
59-
static ptr_type op(EWBinOpId bop, const ptr_type & a_ptr, const ptr_type & b_ptr)
12+
template<typename A, typename B>
13+
static ptr_type op(EWBinOpId bop, const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
6014
{
61-
const auto _a = dynamic_cast<DPTensorX<T>*>(a_ptr.get());
62-
const auto _b = dynamic_cast<DPTensorX<T>*>(b_ptr.get());
63-
if(!_a || !_b)
64-
throw std::runtime_error("Invalid array object: could not dynamically cast");
65-
const auto & a = xt::strided_view(_a->xarray(), _a->lslice());
66-
const auto & b = xt::strided_view(_b->xarray(), _b->lslice());
15+
const auto & a = xt::strided_view(a_ptr->xarray(), a_ptr->lslice());
16+
const auto & b = xt::strided_view(b_ptr->xarray(), b_ptr->lslice());
6717

6818
switch(bop) {
6919
case __ADD__:
7020
case ADD:
71-
return operatorx<T>::mk_tx_(*_a, a + b);
21+
return operatorx<A>::mk_tx_(a_ptr, a + b);
7222
case __RADD__:
73-
return operatorx<T>::mk_tx_(*_a, b + a);
23+
return operatorx<A>::mk_tx_(a_ptr, b + a);
7424
case ATAN2:
75-
return operatorx<T>::mk_tx_(*_a, xt::atan2(a, b));
25+
return operatorx<A>::mk_tx_(a_ptr, xt::atan2(a, b));
7626
case __EQ__:
7727
case EQUAL:
78-
return operatorx<T>::mk_tx_(*_a, xt::equal(a, b));
28+
return operatorx<A>::mk_tx_(a_ptr, xt::equal(a, b));
7929
case __FLOORDIV__:
8030
case FLOOR_DIVIDE:
81-
return operatorx<T>::mk_tx_(*_a, xt::floor(a / b));
31+
return operatorx<A>::mk_tx_(a_ptr, xt::floor(a / b));
8232
case __GE__:
8333
case GREATER_EQUAL:
84-
return operatorx<T>::mk_tx_(*_a, a >= b);
34+
return operatorx<A>::mk_tx_(a_ptr, a >= b);
8535
case __GT__:
8636
case GREATER:
87-
return operatorx<T>::mk_tx_(*_a, a > b);
37+
return operatorx<A>::mk_tx_(a_ptr, a > b);
8838
case __LE__:
8939
case LESS_EQUAL:
90-
return operatorx<T>::mk_tx_(*_a, a <= b);
40+
return operatorx<A>::mk_tx_(a_ptr, a <= b);
9141
case __LT__:
9242
case LESS:
93-
return operatorx<T>::mk_tx_(*_a, a < b);
43+
return operatorx<A>::mk_tx_(a_ptr, a < b);
9444
case __MUL__:
9545
case MULTIPLY:
96-
return operatorx<T>::mk_tx_(*_a, a * b);
46+
return operatorx<A>::mk_tx_(a_ptr, a * b);
9747
case __RMUL__:
98-
return operatorx<T>::mk_tx_(*_a, b * a);
48+
return operatorx<A>::mk_tx_(a_ptr, b * a);
9949
case __NE__:
10050
case NOT_EQUAL:
101-
return operatorx<T>::mk_tx_(*_a, xt::not_equal(a, b));
51+
return operatorx<A>::mk_tx_(a_ptr, xt::not_equal(a, b));
10252
case __SUB__:
10353
case SUBTRACT:
104-
return operatorx<T>::mk_tx_(*_a, a - b);
54+
return operatorx<A>::mk_tx_(a_ptr, a - b);
10555
case __TRUEDIV__:
10656
case DIVIDE:
107-
return operatorx<T>::mk_tx_(*_a, a / b);
57+
return operatorx<A>::mk_tx_(a_ptr, a / b);
10858
case __RFLOORDIV__:
109-
return operatorx<T>::mk_tx_(*_a, xt::floor(b / a));
59+
return operatorx<A>::mk_tx_(a_ptr, xt::floor(b / a));
11060
case __RSUB__:
111-
return operatorx<T>::mk_tx_(*_a, b - a);
61+
return operatorx<A>::mk_tx_(a_ptr, b - a);
11262
case __RTRUEDIV__:
113-
return operatorx<T>::mk_tx_(*_a, b / a);
63+
return operatorx<A>::mk_tx_(a_ptr, b / a);
11464
case __MATMUL__:
11565
case __POW__:
11666
case POW:
@@ -122,15 +72,48 @@ namespace x {
12272
// FIXME
12373
throw std::runtime_error("Binary operation not implemented");
12474
}
125-
return integral_op(bop, *_a, a, b);
75+
if constexpr (std::is_integral<A>::value && std::is_integral<B>::value) {
76+
switch(bop) {
77+
case __AND__:
78+
case BITWISE_AND:
79+
return operatorx<A>::mk_tx_(a_ptr, a & b);
80+
case __RAND__:
81+
return operatorx<A>::mk_tx_(a_ptr, b & a);
82+
case __LSHIFT__:
83+
case BITWISE_LEFT_SHIFT:
84+
return operatorx<A>::mk_tx_(a_ptr, a << b);
85+
case __MOD__:
86+
case REMAINDER:
87+
return operatorx<A>::mk_tx_(a_ptr, a % b);
88+
case __OR__:
89+
case BITWISE_OR:
90+
return operatorx<A>::mk_tx_(a_ptr, a | b);
91+
case __ROR__:
92+
return operatorx<A>::mk_tx_(a_ptr, b | a);
93+
case __RSHIFT__:
94+
case BITWISE_RIGHT_SHIFT:
95+
return operatorx<A>::mk_tx_(a_ptr, a >> b);
96+
case __XOR__:
97+
case BITWISE_XOR:
98+
return operatorx<A>::mk_tx_(a_ptr, a ^ b);
99+
case __RXOR__:
100+
return operatorx<A>::mk_tx_(a_ptr, b ^ a);
101+
case __RLSHIFT__:
102+
return operatorx<A>::mk_tx_(a_ptr, b << a);
103+
case __RMOD__:
104+
return operatorx<A>::mk_tx_(a_ptr, b % a);
105+
case __RRSHIFT__:
106+
return operatorx<A>::mk_tx_(a_ptr, b >> a);
107+
}
108+
}
109+
throw std::runtime_error("Unknown/invalid elementwise binary operation");
126110
}
127-
128111
#pragma GCC diagnostic pop
129112

130113
};
131114
} // namespace x
132115

133116
tensor_i::ptr_type EWBinOp::op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
134117
{
135-
return TypeDispatch<x::EWBinOp>(a->dtype(), op, a, b);
118+
return TypeDispatch2<x::EWBinOp>(a, b, op);
136119
}

0 commit comments

Comments
 (0)