33
44namespace x {
55
6- template <typename T>
76 class EWBinOp
87 {
98 public:
109 using ptr_type = DPTensorBaseX::ptr_type;
1110
1211#pragma GCC diagnostic ignored "-Wswitch"
13-
14- template <typename A, typename B, typename U = T, std::enable_if_t <std::is_floating_point<U>::value, bool > = true >
15- static ptr_type integral_op (EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b)
16- {
17- throw std::runtime_error (" Illegal or unknown inplace elementwise binary operation" );
18- }
19-
20- template <typename A, typename B, typename U = T, std::enable_if_t <std::is_integral<U>::value, bool > = true >
21- static ptr_type integral_op (EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b)
22- {
23- switch (iop) {
24- case __AND__:
25- case BITWISE_AND:
26- return operatorx<T>::mk_tx_ (tx, a & b);
27- case __RAND__:
28- return operatorx<T>::mk_tx_ (tx, b & a);
29- case __LSHIFT__:
30- case BITWISE_LEFT_SHIFT:
31- return operatorx<T>::mk_tx_ (tx, a << b);
32- case __MOD__:
33- case REMAINDER:
34- return operatorx<T>::mk_tx_ (tx, a % b);
35- case __OR__:
36- case BITWISE_OR:
37- return operatorx<T>::mk_tx_ (tx, a | b);
38- case __ROR__:
39- return operatorx<T>::mk_tx_ (tx, b | a);
40- case __RSHIFT__:
41- case BITWISE_RIGHT_SHIFT:
42- return operatorx<T>::mk_tx_ (tx, a >> b);
43- case __XOR__:
44- case BITWISE_XOR:
45- return operatorx<T>::mk_tx_ (tx, a ^ b);
46- case __RXOR__:
47- return operatorx<T>::mk_tx_ (tx, b ^ a);
48- case __RLSHIFT__:
49- return operatorx<T>::mk_tx_ (tx, b << a);
50- case __RMOD__:
51- return operatorx<T>::mk_tx_ (tx, b % a);
52- case __RRSHIFT__:
53- return operatorx<T>::mk_tx_ (tx, b >> a);
54- default :
55- throw std::runtime_error (" Unknown elementwise binary operation" );
56- }
57- }
58-
59- static ptr_type op (EWBinOpId bop, const ptr_type & a_ptr, const ptr_type & b_ptr)
12+ template <typename A, typename B>
13+ static ptr_type op (EWBinOpId bop, const std::shared_ptr<DPTensorX<A>> & a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
6014 {
61- const auto _a = dynamic_cast <DPTensorX<T>*>(a_ptr.get ());
62- const auto _b = dynamic_cast <DPTensorX<T>*>(b_ptr.get ());
63- if (!_a || !_b)
64- throw std::runtime_error (" Invalid array object: could not dynamically cast" );
65- const auto & a = xt::strided_view (_a->xarray (), _a->lslice ());
66- const auto & b = xt::strided_view (_b->xarray (), _b->lslice ());
15+ const auto & a = xt::strided_view (a_ptr->xarray (), a_ptr->lslice ());
16+ const auto & b = xt::strided_view (b_ptr->xarray (), b_ptr->lslice ());
6717
6818 switch (bop) {
6919 case __ADD__:
7020 case ADD:
71- return operatorx<T >::mk_tx_ (*_a , a + b);
21+ return operatorx<A >::mk_tx_ (a_ptr , a + b);
7222 case __RADD__:
73- return operatorx<T >::mk_tx_ (*_a , b + a);
23+ return operatorx<A >::mk_tx_ (a_ptr , b + a);
7424 case ATAN2:
75- return operatorx<T >::mk_tx_ (*_a , xt::atan2 (a, b));
25+ return operatorx<A >::mk_tx_ (a_ptr , xt::atan2 (a, b));
7626 case __EQ__:
7727 case EQUAL:
78- return operatorx<T >::mk_tx_ (*_a , xt::equal (a, b));
28+ return operatorx<A >::mk_tx_ (a_ptr , xt::equal (a, b));
7929 case __FLOORDIV__:
8030 case FLOOR_DIVIDE:
81- return operatorx<T >::mk_tx_ (*_a , xt::floor (a / b));
31+ return operatorx<A >::mk_tx_ (a_ptr , xt::floor (a / b));
8232 case __GE__:
8333 case GREATER_EQUAL:
84- return operatorx<T >::mk_tx_ (*_a , a >= b);
34+ return operatorx<A >::mk_tx_ (a_ptr , a >= b);
8535 case __GT__:
8636 case GREATER:
87- return operatorx<T >::mk_tx_ (*_a , a > b);
37+ return operatorx<A >::mk_tx_ (a_ptr , a > b);
8838 case __LE__:
8939 case LESS_EQUAL:
90- return operatorx<T >::mk_tx_ (*_a , a <= b);
40+ return operatorx<A >::mk_tx_ (a_ptr , a <= b);
9141 case __LT__:
9242 case LESS:
93- return operatorx<T >::mk_tx_ (*_a , a < b);
43+ return operatorx<A >::mk_tx_ (a_ptr , a < b);
9444 case __MUL__:
9545 case MULTIPLY:
96- return operatorx<T >::mk_tx_ (*_a , a * b);
46+ return operatorx<A >::mk_tx_ (a_ptr , a * b);
9747 case __RMUL__:
98- return operatorx<T >::mk_tx_ (*_a , b * a);
48+ return operatorx<A >::mk_tx_ (a_ptr , b * a);
9949 case __NE__:
10050 case NOT_EQUAL:
101- return operatorx<T >::mk_tx_ (*_a , xt::not_equal (a, b));
51+ return operatorx<A >::mk_tx_ (a_ptr , xt::not_equal (a, b));
10252 case __SUB__:
10353 case SUBTRACT:
104- return operatorx<T >::mk_tx_ (*_a , a - b);
54+ return operatorx<A >::mk_tx_ (a_ptr , a - b);
10555 case __TRUEDIV__:
10656 case DIVIDE:
107- return operatorx<T >::mk_tx_ (*_a , a / b);
57+ return operatorx<A >::mk_tx_ (a_ptr , a / b);
10858 case __RFLOORDIV__:
109- return operatorx<T >::mk_tx_ (*_a , xt::floor (b / a));
59+ return operatorx<A >::mk_tx_ (a_ptr , xt::floor (b / a));
11060 case __RSUB__:
111- return operatorx<T >::mk_tx_ (*_a , b - a);
61+ return operatorx<A >::mk_tx_ (a_ptr , b - a);
11262 case __RTRUEDIV__:
113- return operatorx<T >::mk_tx_ (*_a , b / a);
63+ return operatorx<A >::mk_tx_ (a_ptr , b / a);
11464 case __MATMUL__:
11565 case __POW__:
11666 case POW:
@@ -122,15 +72,48 @@ namespace x {
12272 // FIXME
12373 throw std::runtime_error (" Binary operation not implemented" );
12474 }
125- return integral_op (bop, *_a, a, b);
75+ if constexpr (std::is_integral<A>::value && std::is_integral<B>::value) {
76+ switch (bop) {
77+ case __AND__:
78+ case BITWISE_AND:
79+ return operatorx<A>::mk_tx_ (a_ptr, a & b);
80+ case __RAND__:
81+ return operatorx<A>::mk_tx_ (a_ptr, b & a);
82+ case __LSHIFT__:
83+ case BITWISE_LEFT_SHIFT:
84+ return operatorx<A>::mk_tx_ (a_ptr, a << b);
85+ case __MOD__:
86+ case REMAINDER:
87+ return operatorx<A>::mk_tx_ (a_ptr, a % b);
88+ case __OR__:
89+ case BITWISE_OR:
90+ return operatorx<A>::mk_tx_ (a_ptr, a | b);
91+ case __ROR__:
92+ return operatorx<A>::mk_tx_ (a_ptr, b | a);
93+ case __RSHIFT__:
94+ case BITWISE_RIGHT_SHIFT:
95+ return operatorx<A>::mk_tx_ (a_ptr, a >> b);
96+ case __XOR__:
97+ case BITWISE_XOR:
98+ return operatorx<A>::mk_tx_ (a_ptr, a ^ b);
99+ case __RXOR__:
100+ return operatorx<A>::mk_tx_ (a_ptr, b ^ a);
101+ case __RLSHIFT__:
102+ return operatorx<A>::mk_tx_ (a_ptr, b << a);
103+ case __RMOD__:
104+ return operatorx<A>::mk_tx_ (a_ptr, b % a);
105+ case __RRSHIFT__:
106+ return operatorx<A>::mk_tx_ (a_ptr, b >> a);
107+ }
108+ }
109+ throw std::runtime_error (" Unknown/invalid elementwise binary operation" );
126110 }
127-
128111#pragma GCC diagnostic pop
129112
130113 };
131114} // namespace x
132115
133116tensor_i::ptr_type EWBinOp::op (EWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
134117{
135- return TypeDispatch <x::EWBinOp>(a-> dtype (), op, a, b );
118+ return TypeDispatch2 <x::EWBinOp>(a, b, op );
136119}
0 commit comments