Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 47a4282

Browse files
authored
simple fallback to numpy (#39)
* simple fallback to numpy * fallback only if DDPT_FALLBACK is set (to a fallback lib)
1 parent bb83030 commit 47a4282

File tree

3 files changed

+52
-10
lines changed

3 files changed

+52
-10
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ jobs:
146146
. $GITHUB_WORKSPACE/third_party/install/miniconda/etc/profile.d/conda.sh
147147
. $GITHUB_WORKSPACE/third_party/install/miniconda/bin/activate ddpt
148148
cd examples
149-
python -u ./stencil-2d.py 5 1024 star 4
150-
DDPT_FORCE_DIST=1 python -u ./stencil-2d.py 5 1024 star 4
149+
DDPT_FALLBACK=numpy python -u ./stencil-2d.py 5 1024 star 4
150+
DDPT_FALLBACK=numpy DDPT_FORCE_DIST=1 python -u ./stencil-2d.py 5 1024 star 4
151151
python -u ./wave_equation.py -ct
152152
DDPT_FORCE_DIST=1 python -u ./wave_equation.py -ct
153153
cd -

ddptensor/__init__.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# are simply forwarded as-is.
1515

1616
_bool = bool
17+
from typing import Any
1718
from . import _ddptensor as _cdt
1819
from ._ddptensor import (
1920
FLOAT64 as float64,
@@ -34,6 +35,7 @@
3435

3536
from .ddptensor import dtensor
3637
from os import getenv
38+
from importlib import import_module
3739
from . import array_api as api
3840
from . import spmd
3941

@@ -96,12 +98,12 @@ def to_numpy(a):
9698
f"{func} = lambda this, dim=None: dtensor(_cdt.ReduceOp.op(_cdt.{FUNC}, this._t, dim if dim else []))"
9799
)
98100

99-
for func in api.api_categories["ManipOp"]:
100-
FUNC = func.upper()
101-
if func == "reshape":
102-
exec(
103-
f"{func} = lambda this, /, shape, *, copy=None: dtensor(_cdt.ManipOp.reshape(this._t, shape, copy))"
104-
)
101+
# for func in api.api_categories["ManipOp"]:
102+
# FUNC = func.upper()
103+
# if func == "reshape":
104+
# exec(
105+
# f"{func} = lambda this, /, shape, *, copy=None: dtensor(_cdt.ManipOp.reshape(this._t, shape, copy))"
106+
# )
105107

106108
for func in api.api_categories["LinAlgOp"]:
107109
FUNC = func.upper()
@@ -118,3 +120,44 @@ def to_numpy(a):
118120
)
119121
elif func == "matrix_transpose":
120122
exec(f"{func} = lambda this: dtensor(_cdt.LinAlgOp.{func}(this._t))")
123+
124+
125+
_fb_env = getenv("DDPT_FALLBACK")
126+
if _fb_env is not None:
127+
128+
class _fallback:
129+
"Fallback to whatever is provided in DDPT_FALLBACK"
130+
_fb_lib = import_module(_fb_env)
131+
132+
def __init__(self, fname: str, mod=None) -> None:
133+
"get callable with name 'fname' from fallback-lib or throw exception"
134+
self._mod = mod if mod else _fallback._fb_lib
135+
self._func = getattr(self._mod, fname)
136+
137+
def __call__(self, *args: Any, **kwds: Any) -> Any:
138+
"convert ddptensors args to fallback arrays, call fallback-lib and return converted ddptensor"
139+
nargs = []
140+
nkwds = {}
141+
for arg in args:
142+
nargs.append(
143+
spmd.get_locals(arg)[0] if isinstance(arg, dtensor) else arg
144+
)
145+
for k, v in kwds.items():
146+
nkwds[k] = spmd.get_locals(v)[0] if isinstance(v, dtensor) else v
147+
148+
res = self._func(*nargs, **nkwds)
149+
return (
150+
spmd.from_locals(res)
151+
if isinstance(res, _fallback._fb_lib.ndarray)
152+
else res
153+
)
154+
155+
def __getattr__(self, name):
156+
"""Attempt to find a fallback in current fallback object.
157+
This might be necessary if we call something like dt.linalg.norm(...)
158+
"""
159+
return _fallback(name, self._func)
160+
161+
def __getattr__(name):
162+
"Attempt to find a fallback in fallback-lib"
163+
return _fallback(name)

examples/stencil-2d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
else:
6666
from timeit import default_timer as timer
6767

68-
import numpy
6968
import ddptensor as np
7069
import ddptensor.numpy
7170

@@ -192,7 +191,7 @@ def main():
192191
# ******************************************************************************
193192

194193
B = np.spmd.gather(B)
195-
norm = numpy.linalg.norm(numpy.reshape(B, n * n), ord=1)
194+
norm = np.linalg.norm(np.reshape(B, n * n), ord=1)
196195
active_points = (n - 2 * r) ** 2
197196
norm /= active_points
198197

0 commit comments

Comments
 (0)