1414# are simply forwarded as-is.
1515
1616_bool = bool
17+ from typing import Any
1718from . import _ddptensor as _cdt
1819from ._ddptensor import (
1920 FLOAT64 as float64 ,
3435
3536from .ddptensor import dtensor
3637from os import getenv
38+ from importlib import import_module
3739from . import array_api as api
3840from . import spmd
3941
@@ -96,12 +98,12 @@ def to_numpy(a):
9698 f"{ func } = lambda this, dim=None: dtensor(_cdt.ReduceOp.op(_cdt.{ FUNC } , this._t, dim if dim else []))"
9799 )
98100
99- for func in api .api_categories ["ManipOp" ]:
100- FUNC = func .upper ()
101- if func == "reshape" :
102- exec (
103- f"{ func } = lambda this, /, shape, *, copy=None: dtensor(_cdt.ManipOp.reshape(this._t, shape, copy))"
104- )
101+ # for func in api.api_categories["ManipOp"]:
102+ # FUNC = func.upper()
103+ # if func == "reshape":
104+ # exec(
105+ # f"{func} = lambda this, /, shape, *, copy=None: dtensor(_cdt.ManipOp.reshape(this._t, shape, copy))"
106+ # )
105107
106108for func in api .api_categories ["LinAlgOp" ]:
107109 FUNC = func .upper ()
@@ -118,3 +120,44 @@ def to_numpy(a):
118120 )
119121 elif func == "matrix_transpose" :
120122 exec (f"{ func } = lambda this: dtensor(_cdt.LinAlgOp.{ func } (this._t))" )
123+
124+
125+ _fb_env = getenv ("DDPT_FALLBACK" )
126+ if _fb_env is not None :
127+
128+ class _fallback :
129+ "Fallback to whatever is provided in DDPT_FALLBACK"
130+ _fb_lib = import_module (_fb_env )
131+
132+ def __init__ (self , fname : str , mod = None ) -> None :
133+ "get callable with name 'fname' from fallback-lib or throw exception"
134+ self ._mod = mod if mod else _fallback ._fb_lib
135+ self ._func = getattr (self ._mod , fname )
136+
137+ def __call__ (self , * args : Any , ** kwds : Any ) -> Any :
138+ "convert ddptensors args to fallback arrays, call fallback-lib and return converted ddptensor"
139+ nargs = []
140+ nkwds = {}
141+ for arg in args :
142+ nargs .append (
143+ spmd .get_locals (arg )[0 ] if isinstance (arg , dtensor ) else arg
144+ )
145+ for k , v in kwds .items ():
146+ nkwds [k ] = spmd .get_locals (v )[0 ] if isinstance (v , dtensor ) else v
147+
148+ res = self ._func (* nargs , ** nkwds )
149+ return (
150+ spmd .from_locals (res )
151+ if isinstance (res , _fallback ._fb_lib .ndarray )
152+ else res
153+ )
154+
155+ def __getattr__ (self , name ):
156+ """Attempt to find a fallback in current fallback object.
157+ This might be necessary if we call something like dt.linalg.norm(...)
158+ """
159+ return _fallback (name , self ._func )
160+
161+ def __getattr__ (name ):
162+ "Attempt to find a fallback in fallback-lib"
163+ return _fallback (name )
0 commit comments