|
| 1 | +from . import _ddptensor as _cdt |
| 2 | +from .ddptensor import float64, int64, fini, dtensor |
| 3 | +from os import getenv |
| 4 | + |
| 5 | +__impl_str = getenv("DDPNP_ARRAY", 'numpy') |
| 6 | +exec(f"import {__impl_str} as __impl") |
| 7 | + |
| 8 | +ew_binary_ops = [ |
| 9 | + "add", # (x1, x2, /) |
| 10 | + "atan2", # (x1, x2, /) |
| 11 | + "bitwise_and", # (x1, x2, /) |
| 12 | + "bitwise_left_shift", # (x1, x2, /) |
| 13 | + "bitwise_or", # (x1, x2, /) |
| 14 | + "bitwise_right_shift", # (x1, x2, /) |
| 15 | + "bitwise_xor", # (x1, x2, /) |
| 16 | + "divide", # (x1, x2, /) |
| 17 | + "equal", # (x1, x2, /) |
| 18 | + "floor_divide", # (x1, x2, /) |
| 19 | + "greater", # (x1, x2, /) |
| 20 | + "greater_equal", # (x1, x2, /) |
| 21 | + "less_equal", # (x1, x2, /) |
| 22 | + "logaddexp", # (x1, x2) |
| 23 | + "logical_and", # (x1, x2, /) |
| 24 | + "logical_or", # (x1, x2, /) |
| 25 | + "logical_xor", # (x1, x2, /) |
| 26 | + "multiply", # (x1, x2, /) |
| 27 | + "less", # (x1, x2, /) |
| 28 | + "not_equal", # (x1, x2, /) |
| 29 | + "pow", # (x1, x2, /) |
| 30 | + "remainder", # (x1, x2, /) |
| 31 | + "subtract", # (x1, x2, /) |
| 32 | +] |
| 33 | + |
| 34 | +for op in ew_binary_ops: |
| 35 | + exec( |
| 36 | + f"{op} = lambda this, other: dtensor(_cdt.ew_binary_op(this._t, '{op}', other._t if isinstance(other, ddptensor) else other, False))" |
| 37 | + ) |
| 38 | + |
| 39 | +ew_unary_ops = [ |
| 40 | + "abs", # (x, /) |
| 41 | + "acos", # (x, /) |
| 42 | + "acosh", # (x, /) |
| 43 | + "asin", # (x, /) |
| 44 | + "asinh", # (x, /) |
| 45 | + "atan", # (x, /) |
| 46 | + "atanh", # (x, /) |
| 47 | + "bitwise_invert", # (x, /) |
| 48 | + "ceil", # (x, /) |
| 49 | + "cos", # (x, /) |
| 50 | + "cosh", # (x, /) |
| 51 | + "exp", # (x, /) |
| 52 | + "expm1", # (x, /) |
| 53 | + "floor", # (x, /) |
| 54 | + "isfinite", # (x, /) |
| 55 | + "isinf", # (x, /) |
| 56 | + "isnan", # (x, /) |
| 57 | + "logical_not", # (x, /) |
| 58 | + "log", # (x, /) |
| 59 | + "log1p", # (x, /) |
| 60 | + "log2", # (x, /) |
| 61 | + "log10", # (x, /) |
| 62 | + "negative", # (x, /) |
| 63 | + "positive", # (x, /) |
| 64 | + "round", # (x, /) |
| 65 | + "sign", # (x, /) |
| 66 | + "sin", # (x, /) |
| 67 | + "sinh", # (x, /) |
| 68 | + "square", # (x, /) |
| 69 | + "sqrt", # (x, /) |
| 70 | + "tan", # (x, /) |
| 71 | + "tanh", # (x, /) |
| 72 | + "trunc", # (x, /) |
| 73 | +] |
| 74 | + |
| 75 | +for op in ew_unary_ops: |
| 76 | + exec( |
| 77 | + f"{op} = lambda this: dtensor(_cdt.ew_unary_op(this._t, '{op}', False))" |
| 78 | + ) |
| 79 | + |
| 80 | +creators_with_shape = [ |
| 81 | + "empty", # (shape, *, dtype=None, device=None) |
| 82 | + "full", # (shape, fill_value, *, dtype=None, device=None) |
| 83 | + "ones", # (shape, *, dtype=None, device=None) |
| 84 | + "zeros", # (shape, *, dtype=None, device=None) |
| 85 | +] |
| 86 | + |
| 87 | +for func in creators_with_shape: |
| 88 | + exec( |
| 89 | + f"{func} = lambda shape, *args, **kwargs: dtensor(_cdt.create(shape, '{func}', '{__impl_str}', *args, **kwargs))" |
| 90 | + ) |
| 91 | + |
| 92 | +statisticals = [ |
| 93 | + "max", # (x, /, *, axis=None, keepdims=False) |
| 94 | + "mean", # (x, /, *, axis=None, keepdims=False) |
| 95 | + "min", # (x, /, *, axis=None, keepdims=False) |
| 96 | + "prod", # (x, /, *, axis=None, keepdims=False) |
| 97 | + "sum", # (x, /, *, axis=None, keepdims=False) |
| 98 | + "std", # (x, /, *, axis=None, correction=0.0, keepdims=False) |
| 99 | + "var", # (x, /, *, axis=None, correction=0.0, keepdims=False) |
| 100 | +] |
| 101 | + |
| 102 | +for func in statisticals: |
| 103 | + exec( |
| 104 | + f"{func} = lambda this, **kwargs: dtensor(_cdt.reduce_op(this._t, '{func}', **kwargs))" |
| 105 | + ) |
| 106 | + |
| 107 | + |
| 108 | +creators = [ |
| 109 | + "arange", # (start, /, stop=None, step=1, *, dtype=None, device=None) |
| 110 | + "asarray", # (obj, /, *, dtype=None, device=None, copy=None) |
| 111 | + "empty_like", # (x, /, *, dtype=None, device=None) |
| 112 | + "eye", # (n_rows, n_cols=None, /, *, k=0, dtype=None, device=None) |
| 113 | + "from_dlpack", # (x, /) |
| 114 | + "full_like", # (x, /, fill_value, *, dtype=None, device=None) |
| 115 | + "linspace", # (start, stop, /, num, *, dtype=None, device=None, endpoint=True) |
| 116 | + "meshgrid", # (*arrays, indexing=’xy’) |
| 117 | + "ones_like", # (x, /, *, dtype=None, device=None) |
| 118 | + "zeros_like", # (x, /, *, dtype=None, device=None) |
| 119 | +] |
0 commit comments