Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit b303f2f

Browse files
committed
Correctly handle 0-d tensors, 1-element tensors, and various element types
1 parent 5f57e98 commit b303f2f

File tree

9 files changed

+152
-54
lines changed

9 files changed

+152
-54
lines changed

ddptensor/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,21 @@ def to_numpy(a):
6262

6363
for func in api.api_categories["Creator"]:
6464
FUNC = func.upper()
65-
if func in ["empty", "ones", "zeros",]:
65+
if func == "full":
6666
exec(
67-
f"{func} = lambda shape, dtype: dtensor(_cdt.Creator.create_from_shape(_cdt.{FUNC}, shape, dtype))"
67+
f"{func} = lambda shape, val, dtype: dtensor(_cdt.Creator.full(shape, val, dtype))"
6868
)
69-
elif func == "full":
69+
elif func == "empty":
7070
exec(
71-
f"{func} = lambda shape, val, dtype: dtensor(_cdt.Creator.full(shape, val, dtype))"
71+
f"{func} = lambda shape, dtype: dtensor(_cdt.Creator.full(shape, None, dtype))"
72+
)
73+
elif func == "ones":
74+
exec(
75+
f"{func} = lambda shape, dtype: dtensor(_cdt.Creator.full(shape, 1, dtype))"
76+
)
77+
elif func == "zeros":
78+
exec(
79+
f"{func} = lambda shape, dtype: dtensor(_cdt.Creator.full(shape, 0, dtype))"
7280
)
7381
elif func == "arange":
7482
exec(

src/DDPTensorImpl.cpp

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ DDPTensorImpl::DDPTensorImpl(DTypeId dtype, uint64_t ndims,
1616
: _owner(owner),
1717
_allocated(allocated),
1818
_aligned(aligned),
19-
_sizes(new intptr_t[ndims]),
20-
_strides(new intptr_t[ndims]),
2119
_gs_allocated(gs_allocated),
2220
_gs_aligned(gs_aligned),
2321
_lo_allocated(lo_allocated),
@@ -26,8 +24,15 @@ DDPTensorImpl::DDPTensorImpl(DTypeId dtype, uint64_t ndims,
2624
_ndims(ndims),
2725
_dtype(dtype)
2826
{
29-
memcpy(_sizes, sizes, ndims*sizeof(*_sizes));
30-
memcpy(_strides, strides, ndims*sizeof(*_strides));
27+
if(ndims > 0) {
28+
_sizes = new intptr_t[ndims];
29+
_strides = new intptr_t[ndims];
30+
memcpy(_sizes, sizes, ndims*sizeof(*_sizes));
31+
memcpy(_strides, strides, ndims*sizeof(*_strides));
32+
} else {
33+
_owner = REPLICATED;
34+
assert(_aligned);
35+
}
3136
}
3237

3338
DDPTensorImpl::DDPTensorImpl(DTypeId dtype, const shape_type & shp, rank_type owner)
@@ -72,15 +77,17 @@ DDPTensorImpl::ptr_type DDPTensorImpl::clone(bool copy)
7277
gs_allocated, gs_aligned, lo_allocated, lo_aligned, owner());
7378
}
7479

75-
void DDPTensorImpl::alloc()
80+
void DDPTensorImpl::alloc(bool all)
7681
{
7782
auto esz = sizeof_dtype(_dtype);
78-
_allocated = new (std::align_val_t(esz)) char[esz*size()];
83+
_allocated = new (std::align_val_t(esz)) char[esz*local_size()];
7984
_aligned = _allocated;
80-
auto nds = ndims();
81-
_sizes = new intptr_t[nds];
82-
_strides = new intptr_t[nds];
8385
_offset = 0;
86+
if(all) {
87+
auto nds = ndims();
88+
_sizes = new intptr_t[nds];
89+
_strides = new intptr_t[nds];
90+
}
8491
}
8592

8693
void * DDPTensorImpl::data()
@@ -106,8 +113,11 @@ std::string DDPTensorImpl::__repr__() const
106113

107114
dispatch(_dtype, _aligned, [this, nd, &oss](auto * ptr) {
108115
auto cptr = ptr + this->_offset;
109-
if(nd>0) printit(oss, 0, cptr);
110-
else oss << *cptr;
116+
if(nd>0) {
117+
printit(oss, 0, cptr);
118+
} else {
119+
oss << *cptr;
120+
}
111121
});
112122
return oss.str();
113123
}
@@ -189,3 +199,26 @@ void DDPTensorImpl::add_to_args(std::vector<void*> & args, int ndims)
189199
buff[4] = 1;
190200
args.push_back(buff);
191201
}
202+
203+
void DDPTensorImpl::replicate()
204+
{
205+
if(is_replicated()) return;
206+
auto gsz = size();
207+
auto lsz = local_size();
208+
if(gsz > 1) throw(std::runtime_error("Replication implemented for single-element tensors only."));
209+
if(lsz != gsz) {
210+
assert(lsz == 0);
211+
auto nd = ndims();
212+
for(auto i=0; i<nd; ++i) {
213+
_sizes[i] = _strides[i] = 1;
214+
}
215+
_sizes[nd-1] = gsz;
216+
}
217+
dispatch(_dtype, _aligned, [this, lsz, gsz](auto * ptr) {
218+
auto tmp = ptr[this->_offset];
219+
if(lsz != gsz) ptr[this->_offset] = 0;
220+
getTransceiver()->reduce_all(&ptr[this->_offset], this->_dtype, 1, SUM);
221+
assert(lsz != gsz || tmp == ptr[this->_offset]);
222+
});
223+
set_owner(REPLICATED);
224+
}

src/Service.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,20 @@ struct DeferredService : public Deferred
5252

5353
void run()
5454
{
55-
#if 0
5655
switch(_op) {
5756
case REPLICATE: {
5857
const auto a = std::move(Registry::get(_a).get());
59-
set_value(std::move(TypeDispatch<x::Service>(a)));
58+
auto ddpt = dynamic_cast<DDPTensorImpl*>(a.get());
59+
assert(ddpt);
60+
ddpt->replicate();
61+
set_value(a);
6062
break;
6163
}
62-
case DROP:
63-
Registry::del(_a);
64+
case RUN:
6465
break;
6566
default:
66-
throw(std::runtime_error("Unkown Service operation requested."));
67+
throw(std::runtime_error("Unkown Service operation requested."));
6768
}
68-
#endif
6969
}
7070

7171
bool generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::DepManager & dm) override
@@ -76,6 +76,7 @@ struct DeferredService : public Deferred
7676
// FIXME create delete op and return it
7777
break;
7878
case RUN:
79+
case REPLICATE:
7980
return true;
8081
default:
8182
throw(std::runtime_error("Unkown Service operation requested."));

src/idtr.cpp

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ void _idtr_local_shape(id_t guid, void * alloced, void * aligned, intptr_t offse
9494
{
9595
idtr_local_shape(guid, mr_to_ptr<uint64_t>(aligned, offset), nD);
9696
}
97+
} // extern "C"
9798

9899
// convert id of our reduction op to id of imex::ptensor reduction op
99100
static ReduceOpId mlir2ddpt(const ::imex::ptensor::ReduceOpId rop)
@@ -118,17 +119,61 @@ static ReduceOpId mlir2ddpt(const ::imex::ptensor::ReduceOpId rop)
118119
}
119120
}
120121

122+
static DTypeId mlir2ddpt(const ::imex::ptensor::DType dt)
123+
{
124+
switch(dt) {
125+
case ::imex::ptensor::DType::F64:
126+
return FLOAT64;
127+
break;
128+
case ::imex::ptensor::DType::I64:
129+
return INT64;
130+
break;
131+
case ::imex::ptensor::DType::U64:
132+
return UINT64;
133+
break;
134+
case ::imex::ptensor::DType::F32:
135+
return FLOAT32;
136+
break;
137+
case ::imex::ptensor::DType::I32:
138+
return INT32;
139+
break;
140+
case ::imex::ptensor::DType::U32:
141+
return UINT32;
142+
break;
143+
case ::imex::ptensor::DType::I16:
144+
return INT16;
145+
break;
146+
case ::imex::ptensor::DType::U16:
147+
return UINT16;
148+
break;
149+
case ::imex::ptensor::DType::I8:
150+
return INT8;
151+
break;
152+
case ::imex::ptensor::DType::U8:
153+
return UINT8;
154+
break;
155+
case ::imex::ptensor::DType::I1:
156+
return BOOL;
157+
break;
158+
default:
159+
throw std::runtime_error("unknown dtype");
160+
};
161+
}
162+
163+
extern "C" {
121164
// Elementwise inplace allreduce
122-
void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, int op)
165+
void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, ReduceOpId op)
123166
{
124-
getTransceiver()->reduce_all(inout, dtype, N, mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
167+
getTransceiver()->reduce_all(inout, dtype, N, op);
125168
}
126169

127170
// FIXME hard-coded for contiguous layout
128-
void _idtr_reduce_all(uint64_t rank, void * data, int64_t * sizes, int64_t * strides, DTypeId dtype, int op)
171+
void _idtr_reduce_all(uint64_t rank, void * data, int64_t * sizes, int64_t * strides, int dtype, int op)
129172
{
130173
assert(rank == 0 || strides[rank-1] == 1);
131-
idtr_reduce_all(data, dtype, rank ? rank : 1, op);
174+
idtr_reduce_all(data,
175+
mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype)),
176+
rank ? rank : 1,
177+
mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
132178
}
133-
134179
} // extern "C"

src/include/ddptensor/CppTypes.hpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ using InputAdapter = bitsery::InputBufferAdapter<Buffer>;
2222
using Serializer = bitsery::Serializer<OutputAdapter>;
2323
using Deserializer = bitsery::Deserializer<InputAdapter>;
2424

25+
union PyScalar
26+
{
27+
int64_t _int;
28+
double _float;
29+
};
30+
2531
enum _RANKS: rank_type {
2632
NOOWNER = std::numeric_limits<rank_type>::max(),
2733
REPLICATED = std::numeric_limits<rank_type>::max() - 1,
@@ -42,17 +48,17 @@ template<> struct DTYPE<uint8_t> { constexpr static DTypeId value = UINT8; };
4248
template<> struct DTYPE<bool> { constexpr static DTypeId value = BOOL; };
4349

4450
template<DTypeId DT> struct TYPE {};
45-
template<> struct TYPE<FLOAT64> { using dtype = double; };
46-
template<> struct TYPE<FLOAT32> { using dtype = float; };
47-
template<> struct TYPE<INT64> { using dtype = int64_t; };
48-
template<> struct TYPE<INT32> { using dtype = int32_t; };
49-
template<> struct TYPE<INT16> { using dtype = int16_t; };
50-
template<> struct TYPE<INT8> { using dtype = int8_t; };
51-
template<> struct TYPE<UINT64> { using dtype = uint64_t; };
52-
template<> struct TYPE<UINT32> { using dtype = uint32_t; };
53-
template<> struct TYPE<UINT16> { using dtype = uint16_t; };
54-
template<> struct TYPE<UINT8> { using dtype = uint8_t; };
55-
template<> struct TYPE<BOOL> { using dtype = bool; };
51+
template<> struct TYPE<FLOAT64> { using dtype = double; static constexpr bool is_integral = false; static constexpr bool is_float = true; };
52+
template<> struct TYPE<FLOAT32> { using dtype = float; static constexpr bool is_integral = false; static constexpr bool is_float = true; };
53+
template<> struct TYPE<INT64> { using dtype = int64_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
54+
template<> struct TYPE<INT32> { using dtype = int32_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
55+
template<> struct TYPE<INT16> { using dtype = int16_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
56+
template<> struct TYPE<INT8> { using dtype = int8_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
57+
template<> struct TYPE<UINT64> { using dtype = uint64_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
58+
template<> struct TYPE<UINT32> { using dtype = uint32_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
59+
template<> struct TYPE<UINT16> { using dtype = uint16_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
60+
template<> struct TYPE<UINT8> { using dtype = uint8_t; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
61+
template<> struct TYPE<BOOL> { using dtype = bool; static constexpr bool is_integral = true; static constexpr bool is_float = false; };
5662

5763
static size_t sizeof_dtype(const DTypeId dt) {
5864
switch(dt) {

src/include/ddptensor/DDPTensorImpl.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class DDPTensorImpl : public tensor_i
6161

6262
DDPTensorImpl::ptr_type clone(bool copy = true);
6363

64-
void alloc();
64+
void alloc(bool all = true);
6565

6666
~DDPTensorImpl()
6767
{
@@ -100,11 +100,16 @@ class DDPTensorImpl : public tensor_i
100100
{
101101
switch(ndims()) {
102102
case 0 : return 1;
103-
case 1 : return *_sizes;
104-
default: return std::accumulate(_sizes, _sizes+ndims(), 1, std::multiplies<intptr_t>());
103+
case 1 : return *_gs_aligned;
104+
default: return std::accumulate(_gs_aligned, _gs_aligned+ndims(), 1, std::multiplies<intptr_t>());
105105
}
106106
}
107107

108+
uint64_t local_size() const
109+
{
110+
return ndims() == 0 ? 0 : std::accumulate(_sizes, _sizes+ndims(), 1, std::multiplies<intptr_t>());
111+
}
112+
108113
friend struct Service;
109114

110115
virtual bool __bool__() const;
@@ -113,7 +118,7 @@ class DDPTensorImpl : public tensor_i
113118

114119
virtual uint64_t __len__() const
115120
{
116-
return ndims() ? *_sizes : 0;
121+
return ndims() ? *_gs_aligned : 1;
117122
}
118123

119124
bool has_owner() const
@@ -167,6 +172,8 @@ class DDPTensorImpl : public tensor_i
167172
oss << "]";
168173
}
169174
}
175+
176+
void replicate();
170177
};
171178

172179
template<typename ...Ts>

src/include/ddptensor/PyTypes.hpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ namespace py = pybind11;
1010

1111
template<typename T> py::object get_impl_dtype() { return get_impl_dtype(DTYPE<T>::value); };
1212

13-
union PyScalar
14-
{
15-
int64_t _int;
16-
double _float;
17-
};
18-
1913
inline PyScalar mk_scalar(const py::object & b, DTypeId dtype)
2014
{
2115
PyScalar s;

src/include/ddptensor/idtr.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ extern "C" {
2828
void idtr_local_shape(id_t guid, uint64_t * lshape, uint64_t N);
2929

3030
// Elementwise inplace allreduce
31-
void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, int op);
31+
void idtr_reduce_all(void * inout, DTypeId dtype, uint64_t N, ReduceOpId op);
3232

3333
} // extern "C"

test/test_ewb.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import ddptensor as dt
22
import numpy as np
33

4+
mpi_dtypes = [dt.float64, dt.float32, dt.int64, dt.uint64, dt.int32, dt.uint32, dt.int8, dt.uint8]
5+
46
class TestEWB:
57
def test_add1(self):
6-
a = dt.ones([16,16], dtype=dt.float64)
7-
b = dt.ones([16,16], dtype=dt.float64)
8-
c = a + b
9-
r1 = dt.sum(c, [0,1])
10-
v = 16*16*2
11-
assert float(r1) == v
8+
for dtyp in mpi_dtypes:
9+
print(dtyp)
10+
a = dt.ones([6,6], dtype=dtyp)
11+
b = dt.ones([6,6], dtype=dtyp)
12+
c = a + b
13+
r1 = dt.sum(c, [0,1])
14+
v = 6*6*2
15+
assert float(r1) == v
1216

1317
def test_add2(self):
1418
a = dt.ones([16,16], dtype=dt.float64)

0 commit comments

Comments (0)