@@ -9,48 +9,53 @@ namespace x {
99 template <typename T>
1010 struct Rand
1111 {
12- template <typename L, typename U>
13- static ptr_type op (const shape_type & shp, const L & lower, const U & upper)
12+ // template<typename L, typename U>
13+ static ptr_type op (const shape_type & shp, T lower, T upper)
1414 {
15- if constexpr (std::is_floating_point<T>::value) {
16- PVSlice pvslice (shp);
17- shape_type shape (std::move (pvslice.shape_of_rank ()));
18- return operatorx<T>::mk_tx (std::move (pvslice), std::move (xt::random::rand (std::move (shape), to_native<T>(lower), to_native<T>(upper))));
19- }
15+ PVSlice pvslice (shp);
16+ shape_type shape (std::move (pvslice.shape_of_rank ()));
17+ return operatorx<T>::mk_tx (std::move (pvslice), std::move (xt::random::rand (std::move (shape), lower, upper)));
2018 }
2119 };
2220}
2321
22+ template <typename T>
2423struct DeferredRandomOp : public Deferred
2524{
2625 shape_type _shape;
27- py::object _lower, _upper;
28- DTypeId _dtype;
26+ T _lower, _upper;
2927
30- DeferredRandomOp (DTypeId dtype, const shape_type & shape, const py::object & lower, const py::object & upper)
31- : _shape(shape), _lower(lower), _upper(upper), _dtype(dtype)
28+ DeferredRandomOp (const shape_type & shape, T lower, T upper)
29+ : _shape(shape), _lower(lower), _upper(upper)
3230 {}
3331
3432 void run ()
3533 {
36- switch (_dtype) {
37- case FLOAT64:
38- set_value (x::Rand<double >::op (_shape, _lower, _upper));
39- return ;
40- case FLOAT32:
41- set_value (x::Rand<float >::op (_shape, _lower, _upper));
42- return ;
43- }
44- throw std::runtime_error (" rand: dtype must be a floating point type" );
34+ set_value (x::Rand<T>::op (_shape, _lower, _upper));
4535 }
4636};
4737
4838Random::future_type Random::rand (DTypeId dtype, const shape_type & shape, const py::object & lower, const py::object & upper)
4939{
50- return defer<DeferredRandomOp>(dtype, shape, lower, upper);
40+ switch (dtype) {
41+ case FLOAT64: {
42+ double lo = x::to_native<double >(lower);
43+ double up = x::to_native<double >(upper);
44+ return defer ([shape, lo, up](){return x::Rand<double >::op (shape, lo, up);});
45+ // return defer<DeferredRandomOp<double>>(shape, x::to_native<double>(lower), x::to_native<double>(upper));
46+ }
47+ case FLOAT32: {
48+ float lo = x::to_native<float >(lower);
49+ float up = x::to_native<float >(upper);
50+ return defer ([shape, lo, up](){return x::Rand<float >::op (shape, lo, up);});
51+ // return defer<DeferredRandomOp<float>>(shape, x::to_native<double>(lower), x::to_native<double>(upper));
52+ }
53+ default :
54+ throw std::runtime_error (" rand: dtype must be a floating point type" );
55+ }
5156}
5257
5358void Random::seed (uint64_t s)
5459{
55- xt::random::seed (s);
60+ defer ([s](){ xt::random::seed (s); return tensor_i::ptr_type ();} );
5661}
0 commit comments