Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit fc1ac86

Browse files
committed
Initial commit
1 parent 2b6c9c9 commit fc1ac86

23 files changed

+2506
-0
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "third_party/bitsery"]
	path = third_party/bitsery
	url = https://github.com/fraillt/bitsery

ddptensor/__init__.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from . import _ddptensor as _cdt
from .ddptensor import float64, int64, fini, dtensor
from os import getenv
from importlib import import_module

# Name of the local array package backing distributed tensors; "numpy" by
# default, overridable via the DDPNP_ARRAY environment variable (e.g. "torch").
__impl_str = getenv("DDPNP_ARRAY", 'numpy')
# BUG FIX / security: was exec(f"import {__impl_str} as __impl"), which executes
# an environment-supplied string as code.  import_module performs only a module
# lookup with the same resulting binding.
__impl = import_module(__impl_str)
7+
8+
# Elementwise binary functions of the array-API; each name below becomes a
# module-level function (add(x1, x2), multiply(x1, x2), ...) forwarding to the
# C++ runtime via _cdt.ew_binary_op.
ew_binary_ops = [
    "add",  # (x1, x2, /)
    "atan2",  # (x1, x2, /)
    "bitwise_and",  # (x1, x2, /)
    "bitwise_left_shift",  # (x1, x2, /)
    "bitwise_or",  # (x1, x2, /)
    "bitwise_right_shift",  # (x1, x2, /)
    "bitwise_xor",  # (x1, x2, /)
    "divide",  # (x1, x2, /)
    "equal",  # (x1, x2, /)
    "floor_divide",  # (x1, x2, /)
    "greater",  # (x1, x2, /)
    "greater_equal",  # (x1, x2, /)
    "less_equal",  # (x1, x2, /)
    "logaddexp",  # (x1, x2)
    "logical_and",  # (x1, x2, /)
    "logical_or",  # (x1, x2, /)
    "logical_xor",  # (x1, x2, /)
    "multiply",  # (x1, x2, /)
    "less",  # (x1, x2, /)
    "not_equal",  # (x1, x2, /)
    "pow",  # (x1, x2, /)
    "remainder",  # (x1, x2, /)
    "subtract",  # (x1, x2, /)
]


def _make_ew_binary_op(name):
    """Build the module-level wrapper for elementwise binary op *name*."""
    def _op(this, other):
        # BUG FIX: original tested isinstance(other, ddptensor), but no name
        # "ddptensor" is bound in this module (the wrapper class is dtensor),
        # so every call with a tensor operand raised NameError.
        return dtensor(
            _cdt.ew_binary_op(this._t, name, other._t if isinstance(other, dtensor) else other, False)
        )
    _op.__name__ = name
    return _op


# Closure factory instead of exec(): same public names, no code built from strings.
for op in ew_binary_ops:
    globals()[op] = _make_ew_binary_op(op)
38+
39+
# Elementwise unary functions of the array-API; each name below becomes a
# module-level function (abs(x), sin(x), ...) forwarding to the C++ runtime.
ew_unary_ops = [
    "abs",  # (x, /)
    "acos",  # (x, /)
    "acosh",  # (x, /)
    "asin",  # (x, /)
    "asinh",  # (x, /)
    "atan",  # (x, /)
    "atanh",  # (x, /)
    "bitwise_invert",  # (x, /)
    "ceil",  # (x, /)
    "cos",  # (x, /)
    "cosh",  # (x, /)
    "exp",  # (x, /)
    "expm1",  # (x, /)
    "floor",  # (x, /)
    "isfinite",  # (x, /)
    "isinf",  # (x, /)
    "isnan",  # (x, /)
    "logical_not",  # (x, /)
    "log",  # (x, /)
    "log1p",  # (x, /)
    "log2",  # (x, /)
    "log10",  # (x, /)
    "negative",  # (x, /)
    "positive",  # (x, /)
    "round",  # (x, /)
    "sign",  # (x, /)
    "sin",  # (x, /)
    "sinh",  # (x, /)
    "square",  # (x, /)
    "sqrt",  # (x, /)
    "tan",  # (x, /)
    "tanh",  # (x, /)
    "trunc",  # (x, /)
]


def _make_ew_unary_op(name):
    """Build the module-level wrapper for elementwise unary op *name*."""
    def _op(this):
        return dtensor(_cdt.ew_unary_op(this._t, name, False))
    _op.__name__ = name
    return _op


# Closure factory instead of exec(): same public names, no code built from strings.
for op in ew_unary_ops:
    globals()[op] = _make_ew_unary_op(op)
79+
80+
# Array-API creation functions whose first argument is a shape; each becomes a
# module-level function delegating to _cdt.create with the configured backend.
creators_with_shape = [
    "empty",  # (shape, *, dtype=None, device=None)
    "full",  # (shape, fill_value, *, dtype=None, device=None)
    "ones",  # (shape, *, dtype=None, device=None)
    "zeros",  # (shape, *, dtype=None, device=None)
]


def _make_creator(name):
    """Build the module-level creation wrapper for *name*."""
    def _create(shape, *args, **kwargs):
        # __impl_str is read at call time; it is set once at import and never
        # mutated, so this matches the original f-string-baked value.
        return dtensor(_cdt.create(shape, name, __impl_str, *args, **kwargs))
    _create.__name__ = name
    return _create


# Closure factory instead of exec(): same public names, no code built from strings.
for func in creators_with_shape:
    globals()[func] = _make_creator(func)
91+
92+
# Array-API statistical reductions; each becomes a module-level function
# forwarding to _cdt.reduce_op (axis/keepdims etc. passed through as kwargs).
statisticals = [
    "max",  # (x, /, *, axis=None, keepdims=False)
    "mean",  # (x, /, *, axis=None, keepdims=False)
    "min",  # (x, /, *, axis=None, keepdims=False)
    "prod",  # (x, /, *, axis=None, keepdims=False)
    "sum",  # (x, /, *, axis=None, keepdims=False)
    "std",  # (x, /, *, axis=None, correction=0.0, keepdims=False)
    "var",  # (x, /, *, axis=None, correction=0.0, keepdims=False)
]


def _make_reduction(name):
    """Build the module-level reduction wrapper for *name*."""
    def _reduce(this, **kwargs):
        return dtensor(_cdt.reduce_op(this._t, name, **kwargs))
    _reduce.__name__ = name
    return _reduce


# Closure factory instead of exec(): same public names, no code built from strings.
for func in statisticals:
    globals()[func] = _make_reduction(func)
106+
107+
108+
# Remaining array-API creation functions (those not taking a plain shape as
# their first argument).  NOTE(review): unlike the lists above, no loop in this
# chunk wires these names up — confirm they are bound elsewhere or still TODO.
creators = [
    "arange",  # (start, /, stop=None, step=1, *, dtype=None, device=None)
    "asarray",  # (obj, /, *, dtype=None, device=None, copy=None)
    "empty_like",  # (x, /, *, dtype=None, device=None)
    "eye",  # (n_rows, n_cols=None, /, *, k=0, dtype=None, device=None)
    "from_dlpack",  # (x, /)
    "full_like",  # (x, /, fill_value, *, dtype=None, device=None)
    "linspace",  # (start, stop, /, num, *, dtype=None, device=None, endpoint=True)
    "meshgrid",  # (*arrays, indexing='xy')
    "ones_like",  # (x, /, *, dtype=None, device=None)
    "zeros_like",  # (x, /, *, dtype=None, device=None)
]

ddptensor/ddptensor.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from . import _ddptensor as _cdt
2+
from ._ddptensor import float64, int64, fini
3+
4+
# Elementwise binary dunder methods attached to dtensor below; each forwards
# to the C++ runtime.  Names must match Python's data model exactly or the
# interpreter will never call them.
ew_binary_methods = [
    "__add__",  # (self, other, /)
    "__and__",  # (self, other, /)
    "__eq__",  # (self, other, /)
    "__floordiv__",  # (self, other, /)
    "__ge__",  # (self, other, /)
    "__gt__",  # (self, other, /)
    "__le__",  # (self, other, /)
    "__lshift__",  # (self, other, /)
    "__lt__",  # (self, other, /)
    "__matmul__",  # (self, other, /)
    "__mod__",  # (self, other, /)
    "__mul__",  # (self, other, /)
    "__ne__",  # (self, other, /)
    "__or__",  # (self, other, /)
    "__pow__",  # (self, other, /)
    "__rshift__",  # (self, other, /)
    "__sub__",  # (self, other, /)
    "__truediv__",  # (self, other, /)
    "__xor__",  # (self, other, /)
    # reflected operators
    "__radd__",
    "__rand__",
    "__rfloordiv__",  # BUG FIX: was "__rflowdiv__", not a real dunder
    "__rlshift__",
    "__rmod__",
    "__rmul__",
    "__ror__",
    "__rpow__",
    "__rrshift__",
    "__rsub__",
    "__rtruediv__",
    "__rxor__",
]
38+
39+
# In-place binary dunder methods attached to dtensor below; each mutates the
# underlying tensor via the runtime and returns self.
ew_binary_methods_inplace = [
    # inplace operators
    "__iadd__",
    "__iand__",
    "__ifloordiv__",  # BUG FIX: was "__iflowdiv__", not a real dunder
    "__ilshift__",
    "__imod__",
    "__imul__",
    "__ior__",
    "__ipow__",
    "__irshift__",
    "__isub__",
    "__itruediv__",
    "__ixor__",
]
54+
55+
# Elementwise unary dunder methods attached to dtensor below; each returns a
# new dtensor computed by the C++ runtime.
ew_unary_methods = [
    "__abs__",  # (self, /)
    "__invert__",  # (self, /)
    "__neg__",  # (self, /)
    "__pos__",  # (self, /)
]
61+
62+
# Zero-argument dunder methods forwarded verbatim to the underlying tensor
# handle (no dtensor wrapping of the result).  Commented entries are part of
# the array-API surface but not enabled yet.
unary_methods = [
    # "__array_namespace__",  # (self, /, *, api_version=None)
    "__bool__",  # (self, /)
    # "__dlpack__",  # (self, /, *, stream=None)
    # "__dlpack_device__",  # (self, /)
    "__float__",  # (self, /)
    "__int__",  # (self, /)
    "__len__",  # (self, /)
]
71+
72+
# Read-only attributes forwarded from the underlying tensor handle as
# properties on dtensor.  NOTE(review): "device", "ndim", "size", "T" were
# sketched but not enabled — confirm the extension exposes them before adding.
t_attributes = ["dtype", "shape", ]
80+
class dtensor:
    """Thin Python wrapper around a C++ distributed-tensor handle (``_t``).

    All arithmetic/comparison dunders, in-place operators, a few conversion
    dunders and the attributes listed in ``t_attributes`` are attached to the
    class by the loops below, each forwarding to the ``_ddptensor`` runtime.
    """

    def __init__(self, t):
        # t: opaque tensor handle produced by the _ddptensor extension.
        self._t = t

    def __repr__(self):
        return self._t.__repr__()

    # The original defined __eq__ inside the class body, which implicitly sets
    # __hash__ to None; keep instances unhashable explicitly since __eq__ is
    # now attached after class creation.
    __hash__ = None

    def __getitem__(self, *args):
        return dtensor(self._t.__getitem__(*args))

    def __setitem__(self, key, value):
        # In-place update; returns None like any __setitem__.
        self._t.__setitem__(key, value._t if isinstance(value, dtensor) else value)


# Closure factories instead of exec() inside the class body: same attached
# method names, no code built from strings, and no loop variables leaking
# into the class namespace.

def _binary_method(name):
    """Forwarding elementwise binary dunder (e.g. __add__) for dtensor."""
    def _m(self, other):
        return dtensor(
            _cdt.ew_binary_op(self._t, name, other._t if isinstance(other, dtensor) else other, True)
        )
    _m.__name__ = name
    return _m


def _inplace_method(name):
    """Forwarding in-place dunder (e.g. __iadd__); returns self as required."""
    def _m(self, other):
        _cdt.ew_binary_op_inplace(self._t, name, other._t if isinstance(other, dtensor) else other)
        return self
    _m.__name__ = name
    return _m


def _unary_method(name):
    """Forwarding elementwise unary dunder (e.g. __neg__) for dtensor."""
    def _m(self):
        return dtensor(_cdt.ew_unary_op(self._t, name, True))
    _m.__name__ = name
    return _m


def _forward_method(name):
    """Zero-argument dunder forwarded verbatim to the underlying handle."""
    def _m(self):
        return getattr(self._t, name)()
    _m.__name__ = name
    return _m


for _name in ew_binary_methods:
    setattr(dtensor, _name, _binary_method(_name))
for _name in ew_binary_methods_inplace:
    setattr(dtensor, _name, _inplace_method(_name))
for _name in ew_unary_methods:
    setattr(dtensor, _name, _unary_method(_name))
for _name in unary_methods:
    setattr(dtensor, _name, _forward_method(_name))
for _name in t_attributes:
    # Default-arg binding fixes the loop variable per property.
    setattr(dtensor, _name, property(lambda self, _a=_name: getattr(self._t, _a)))
del _name

ddptensor/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# NOTE(review): was "import dtensor" — no top-level module of that name exists;
# this package is installed as "ddptensor" (see setup.py), so the bare import
# would fail at package-import time.  Confirm the intended target.
import ddptensor as dtensor

ddptensor/numpy/random.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# NOTE(review): was "from .. import dist_tensor" — the extension module built
# by setup.py is "_ddptensor" (as imported in ddptensor/__init__.py); and the
# dtensor class lives in ddptensor.ddptensor, not ".dtensor".
from .. import _ddptensor as _cdt
from .. import __impl
from ..ddptensor import dtensor
4+
5+
def seed(s=None):
    """Seed the local RNG, offset by MPI rank so each process draws a distinct stream.

    BUG FIX: the original used ``if s``, which silently discarded an explicit
    seed of 0; only ``None`` (no seed given) should fall back to the bare rank.
    """
    __impl.random.seed(_cdt.myrank() if s is None else s + _cdt.myrank())
7+
8+
def _uniform_numpy(shape, start, stop, dtype=None):
    """Local-chunk sampler invoked by the runtime: uniform values in [start, stop).

    ``dtype`` is accepted for interface compatibility but currently unused.
    """
    low, high = start, stop
    return __impl.random.uniform(low, high, shape)
10+
11+
def uniform(start, stop, shape):
    """Create a distributed tensor of samples drawn uniformly from [start, stop).

    BUG FIX: the sampler module path was 'dtensor.numpy.random', but this
    package is installed as 'ddptensor' (see setup.py), so the runtime could
    never import the local sampler.
    """
    return dtensor(_cdt.create(shape, '_uniform_numpy', 'ddptensor.numpy.random', start, stop))
13+
14+
# for func in ["seed", "uniform"]:
15+
# exec(
16+
# f"{func} = staticmethod(lambda shape, *args, **kwargs: dtensor(_cdt.create('{func}', _cdt.__dlp_provider_name + '.random', *args, **kwargs)))"
17+
# )

ddptensor/torch/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# NOTE(review): was "from .. import dist_tensor" — the extension module built
# by setup.py is "_ddptensor"; and bare "import dtensor" cannot resolve since
# the package is installed as "ddptensor".  Confirm the intended targets.
from .. import _ddptensor as _cdt
import ddptensor as dtensor
3+
4+
def manual_seed(s=None):
    """Seed torch's RNG, offset by MPI rank so each process draws a distinct stream.

    BUG FIX: the original used ``if s``, which silently discarded an explicit
    seed of 0; only ``None`` (no seed given) should fall back to the bare rank.
    """
    dtensor.__impl.manual_seed(_cdt.myrank() if s is None else s + _cdt.myrank())
6+
7+
def _rand_torch(shape, *args, **kwargs):
    """Local-chunk sampler invoked by the runtime: uniform [0, 1) torch tensor.

    Extra positional/keyword arguments are accepted and ignored.
    """
    dims = tuple(shape)
    return dtensor.__impl.rand(dims)
9+
10+
def rand(shape, *args, **kwargs):
    """Create a distributed tensor of uniform [0, 1) samples via the torch backend.

    BUG FIX: the sampler module path was 'dtensor.torch', but this package is
    installed as 'ddptensor' (see setup.py), so the runtime could never import
    the local sampler.
    """
    return dtensor.dtensor(_cdt.create(shape, '_rand_torch', 'ddptensor.torch', *args, **kwargs))
12+
13+
def erf(ary):
    """Elementwise error function of a distributed tensor, as a new dtensor."""
    handle = _cdt.ew_unary_op(ary._t, 'erf', False)
    return dtensor.dtensor(handle)

setup.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Build script for the ddptensor package and its pybind11 C++ extension."""
import os
from os.path import join as jp
from glob import glob
from setuptools import setup
from pybind11.setup_helpers import Pybind11Extension

# Root of the MPI installation; headers and libraries are resolved against it.
mpiroot = os.environ.get('MPIROOT')
if not mpiroot:
    # BUG FIX: fail fast with a clear message; previously an unset MPIROOT
    # surfaced as an obscure TypeError from os.path.join(None, "include").
    raise RuntimeError("The MPIROOT environment variable must point to an MPI installation")

ext_modules = [
    Pybind11Extension(
        "ddptensor._ddptensor",
        glob("src/*.cpp"),
        include_dirs=[jp(mpiroot, "include"), jp("third_party", "bitsery", "include"), jp("src", "include"), ],
        # FIXME(review): -O0 -g are debug settings baked into every build; gate
        # them behind an env flag (or drop them) before release builds.
        extra_compile_args=["-DUSE_MKL", "-std=c++17", "-Wno-unused-but-set-variable", "-Wno-sign-compare", "-Wno-unused-local-typedefs", "-Wno-reorder", "-O0", "-g"],
        libraries=["mpi", "rt", "pthread", "dl", "mkl_intel_lp64", "mkl_intel_thread", "mkl_core", "iomp5", "m"],
        library_dirs=[jp(mpiroot, "lib")],
        language='c++',
    ),
]

setup(
    name="ddptensor",
    version="0.1",
    description="Distributed Tensor and more",
    packages=["ddptensor", "ddptensor.numpy", "ddptensor.torch"],
    ext_modules=ext_modules,
)

0 commit comments

Comments
 (0)