Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 5ca13fe

Browse files
committed
allow scalars at right operand to binops
1 parent aadf34e commit 5ca13fe

File tree

6 files changed

+40
-7
lines changed

6 files changed

+40
-7
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def build_cmake(self, ext):
2929
extdir.parent.mkdir(parents=True, exist_ok=True)
3030

3131
# example of cmake args
32-
config = 'Debug' if self.debug else 'Release'
32+
config = 'Debug' # if self.debug else 'Release'
3333
cmake_args = [
3434
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(extdir.parent.absolute()),
3535
'-DCMAKE_BUILD_TYPE=' + config

src/EWBinOp.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ namespace x {
113113
};
114114
} // namespace x
115115

116-
tensor_i::ptr_type EWBinOp::op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
116+
tensor_i::ptr_type EWBinOp::op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, py::object b)
117117
{
118-
return TypeDispatch<x::EWBinOp>(a, b, op);
118+
auto bb = x::mk_tx(b);
119+
return TypeDispatch<x::EWBinOp>(a, bb, op);
119120
}

src/IEWBinOp.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ namespace x {
6262
};
6363
} // namespace x
6464

65-
void IEWBinOp::op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
65+
void IEWBinOp::op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, py::object b)
6666
{
67-
TypeDispatch<x::IEWBinOp>(a, b, op);
67+
auto bb = x::mk_tx(b);
68+
TypeDispatch<x::IEWBinOp>(a, bb, op);
6869
}

src/include/ddptensor/Operations.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ struct Creator
1515

1616
struct IEWBinOp
1717
{
18-
static void op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b);
18+
static void op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, py::object b);
1919
};
2020

2121
struct EWBinOp
2222
{
23-
static tensor_i::ptr_type op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b);
23+
static tensor_i::ptr_type op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, py::object b);
2424
};
2525

2626
struct EWUnyOp

src/include/ddptensor/x.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ namespace x
8585
_xarray = org;
8686
}
8787

88+
DPTensorX(const T & v)
89+
: _owner(theTransceiver->rank()),
90+
_slice(shape_type{1}),
91+
_lslice({xt::newaxis()}), //to_xt(_slice.slice())),
92+
_xarray(std::make_shared<xt::xarray<T>>(1)),
93+
_replica(v)
94+
{
95+
*_xarray = v;
96+
}
97+
8898
virtual std::string __repr__() const
8999
{
90100
auto v = xt::strided_view(xarray(), lslice());
@@ -235,6 +245,11 @@ namespace x
235245
{
236246
public:
237247

248+
static DPTensorBaseX::ptr_type mk_tx(py::object & o)
249+
{
250+
return std::make_shared<DPTensorX<T>>(o.cast<T>());
251+
}
252+
238253
template<typename ...Ts>
239254
static DPTensorBaseX::ptr_type mk_tx(Ts&&... args)
240255
{
@@ -254,4 +269,16 @@ namespace x
254269
}
255270
};
256271

272+
static DPTensorBaseX::ptr_type mk_tx(py::object & b)
273+
{
274+
if(py::isinstance<x::DPTensorBaseX::ptr_type>(b)) {
275+
return b.cast<x::DPTensorBaseX::ptr_type>();
276+
} else if(py::isinstance<py::float_>(b)) {
277+
return x::operatorx<double>::mk_tx(b);
278+
} else if(py::isinstance<py::int_>(b)) {
279+
return x::operatorx<int64_t>::mk_tx(b);
280+
}
281+
throw std::runtime_error("Invalid right operand to elementwise binary operation");
282+
};
283+
257284
} // namespace x

test/test_scalar.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import ddptensor as dt
2+
a = dt.ones((8,8), dt.float64)
3+
b = a + 1
4+
print(b)

0 commit comments

Comments
 (0)