Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 8c4463d

Browse files
authored
Merge pull request #1 from intel-sandbox/deferred
Deferred
2 parents baec5fa + 3caed78 commit 8c4463d

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

72 files changed

+4397
-1601
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# emacs
132+
*~

.gitmodules

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
11
[submodule "third_party/bitsery"]
22
path = third_party/bitsery
33
url = https://github.com/fraillt/bitsery
4+
[submodule "third_party/xtensor"]
5+
path = third_party/xtensor
6+
url = https://github.com/xtensor-stack/xtensor
7+
[submodule "third_party/xtensor-blas"]
8+
path = third_party/xtensor-blas
9+
url = https://github.com/xtensor-stack/xtensor-blas
10+
[submodule "third_party/xsimd"]
11+
path = third_party/xsimd
12+
url = https://github.com/xtensor-stack/xsimd
13+
[submodule "third_party/xtl"]
14+
path = third_party/xtl
15+
url = https://github.com/xtensor-stack/xtl

CMakeLists.txt

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
cmake_minimum_required(VERSION 3.18.2)
project(ddptensor VERSION 1.0)

# ---------------------------------------------------------------------------
# C++ standard: require C++17 without compiler extensions (-std=c++17,
# not -std=gnu++17).
# ---------------------------------------------------------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_EXTENSIONS OFF)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# ===============
# Deps
# ===============

# Python (interpreter + extension-module headers + NumPy) is needed both to
# build the pybind11 module and to run the code generator below.
find_package(Python3 COMPONENTS Interpreter Development.Module NumPy REQUIRED)
# REQUIRED: pybind11_add_module below is used unconditionally, so a missing
# pybind11 should fail here with a clear message, not later with
# "unknown command pybind11_add_module".
find_package(pybind11 CONFIG REQUIRED)
find_package(MPI REQUIRED)

# MKL is located via the MKLROOT environment variable at configure time.
# NOTE(review): raw -L/-l link flags; prefer find_package(MKL CONFIG) and the
# imported MKL::MKL target once the oneMKL CMake config is available on all
# build hosts — confirm before switching.
set(MKL_LIBRARIES -L$ENV{MKLROOT}/lib -lmkl_intel_lp64 -lmkl_tbb_thread -lmkl_core -ltbb -lpthread -lrt -ldl -lm)

# Use -fPIC even if statically compiled
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# ---------------------------------------------------------------------------
# Generated sources
# ---------------------------------------------------------------------------
set(P2C_HPP ${PROJECT_SOURCE_DIR}/src/include/ddptensor/p2c_ids.hpp)
# Generate the Python-to-C++ id header from the array-API listing.
# Use the interpreter discovered by find_package(Python3) — a bare "python"
# may not exist (or be the wrong version) on the build host — and VERBATIM
# for portable argument escaping.
add_custom_command(
  OUTPUT ${P2C_HPP}
  COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/code_gen.py ${PROJECT_SOURCE_DIR}/ddptensor/array_api.py ${P2C_HPP}
  DEPENDS ${PROJECT_SOURCE_DIR}/scripts/code_gen.py ${PROJECT_SOURCE_DIR}/ddptensor/array_api.py
  COMMENT "Generating ${P2C_HPP}."
  VERBATIM
)

# ============
# Target
# ============
# CONFIGURE_DEPENDS makes the glob re-run at build time so newly added
# sources are picked up without a manual re-configure.
file(GLOB MyCppSources CONFIGURE_DEPENDS
  ${PROJECT_SOURCE_DIR}/src/*.cpp
  ${PROJECT_SOURCE_DIR}/src/include/ddptensor/*.hpp)
list(APPEND MyCppSources ${P2C_HPP})

pybind11_add_module(_ddptensor MODULE ${MyCppSources})

target_compile_definitions(_ddptensor PRIVATE XTENSOR_USE_XSIMD=1 XTENSOR_USE_TBB=1 DDPT_2TYPES=1 USE_MKL=1)
target_include_directories(_ddptensor PRIVATE
  ${PROJECT_SOURCE_DIR}/src/include
  ${PROJECT_SOURCE_DIR}/third_party/xtl/include
  ${PROJECT_SOURCE_DIR}/third_party/xsimd/include
  ${PROJECT_SOURCE_DIR}/third_party/xtensor-blas/include
  ${PROJECT_SOURCE_DIR}/third_party/xtensor/include
  ${PROJECT_SOURCE_DIR}/third_party/bitsery/include
  ${MPI_INCLUDE_PATH} $ENV{MKLROOT}/include
  ${pybind11_INCLUDE_DIRS})
target_link_libraries(_ddptensor PRIVATE ${MPI_C_LIBRARIES} ${MKL_LIBRARIES})

ddptensor/__init__.py

Lines changed: 91 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,104 @@
"""
Distributed implementation of the array API as defined here:
https://data-apis.org/array-api/latest
"""

# Many features of the API are very uniformly defined.
# We make use of that by providing lists of operations which are similar
# (see array_api.py). __init__.py and ddptensor.py simply generate the API
# by iterating through these lists and forwarding the function calls to the
# C++-extension. Python functions are defined and added by using "exec".
# For many operations we assume the C++-extension defines enums which allow
# us to identify each operation.
# At this point there are no checks of input arguments whatsoever, arguments
# are simply forwarded as-is.

# The array API mandates a dtype named "bool", which the re-export below
# shadows; keep a reference to the builtin first.
_bool = bool
from . import _ddptensor as _cdt
from ._ddptensor import (
    FLOAT64 as float64,
    FLOAT32 as float32,
    INT64 as int64,
    INT32 as int32,
    INT16 as int16,
    INT8 as int8,
    UINT64 as uint64,
    UINT32 as uint32,
    UINT16 as uint16,
    UINT8 as uint8,
    BOOL as bool,
    init as _init,
    fini,
    sync
)

from .ddptensor import dtensor
from os import getenv
from . import array_api as api
from . import spmd

# Default for init()'s "cw" argument, read from the environment once at
# import time. The default is the integer string "1" (truthy): passing the
# bool True through getenv() only worked because int(True) == 1.
# NOTE(review): the value must be an integer string ("0"/"1"); anything else
# (e.g. "true") raises ValueError at import — confirm whether richer parsing
# is wanted.
_ddpt_cw = _bool(int(getenv("DDPT_CW", "1")))

def init(cw=None):
    """Initialize the runtime.

    cw: optional bool; falls back to the DDPT_CW environment default.
    """
    cw = _ddpt_cw if cw is None else cw
    _init(cw)

def to_numpy(a):
    """Convert a dtensor to a local NumPy array via the C++ extension."""
    return _cdt.to_numpy(a._t)

# Element-wise binary ops: generate one module-level function per listed op,
# skipping dunder variants. The second operand may be a dtensor or a scalar.
for op in api.api_categories["EWBinOp"]:
    if not op.startswith("__"):
        OP = op.upper()
        exec(
            f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t if isinstance(other, ddptensor) else other))"
        )

# Element-wise unary ops.
for op in api.api_categories["EWUnyOp"]:
    if not op.startswith("__"):
        OP = op.upper()
        exec(
            f"{op} = lambda this: dtensor(_cdt.EWUnyOp.op(_cdt.{OP}, this._t))"
        )

# Creation functions; only the listed subset is currently implemented.
for func in api.api_categories["Creator"]:
    FUNC = func.upper()
    if func in ["empty", "ones", "zeros",]:
        exec(
            f"{func} = lambda shape, dtype: dtensor(_cdt.Creator.create_from_shape(_cdt.{FUNC}, shape, dtype))"
        )
    elif func == "full":
        exec(
            f"{func} = lambda shape, val, dtype: dtensor(_cdt.Creator.full(shape, val, dtype))"
        )
    elif func == "arange":
        exec(
            f"{func} = lambda start, end, step, dtype: dtensor(_cdt.Creator.arange(start, end, step, dtype))"
        )

# Reductions over a given dimension.
for func in api.api_categories["ReduceOp"]:
    FUNC = func.upper()
    exec(
        f"{func} = lambda this, dim: dtensor(_cdt.ReduceOp.op(_cdt.{FUNC}, this._t, dim))"
    )

# Manipulation ops; only reshape is currently implemented.
for func in api.api_categories["ManipOp"]:
    FUNC = func.upper()
    if func == "reshape":
        exec(
            f"{func} = lambda this, shape: dtensor(_cdt.ManipOp.reshape(this._t, shape))"
        )

# Linear-algebra ops.
for func in api.api_categories["LinAlgOp"]:
    FUNC = func.upper()
    if func in ["tensordot", "vecdot",]:
        exec(
            f"{func} = lambda this, other, axis: dtensor(_cdt.LinAlgOp.{func}(this._t, other._t, axis))"
        )
    elif func == "matmul":
        # NOTE(review): matmul is forwarded to vecdot(..., 0) — this looks
        # like a placeholder, not general matrix multiplication; confirm the
        # intended semantics before relying on it.
        exec(
            f"{func} = lambda this, other: dtensor(_cdt.LinAlgOp.vecdot(this._t, other._t, 0))"
        )
    elif func == "matrix_transpose":
        exec(
            f"{func} = lambda this: dtensor(_cdt.LinAlgOp.{func}(this._t))"
        )

0 commit comments

Comments
 (0)