Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 5b07b3e

Browse files
committed
adding full creator, supporting nd-tensors
1 parent ea8bcc6 commit 5b07b3e

File tree

6 files changed

+295
-54
lines changed

6 files changed

+295
-54
lines changed

ddptensor/ddptensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _inplace(self, t):
5151
)
5252

5353
# Indexing: delegates to the wrapped tensor (self._t) and wraps the result in a new dtensor.
# A key that is not already a tuple is packed into a one-element tuple before it is passed on.
# Diff view: the "54-" line is the previous revision (list-based), the "54+" line the current one.
def __getitem__(self, key):
54-
return dtensor(self._t.__getitem__(key if isinstance(key, list) else [key,]))
54+
return dtensor(self._t.__getitem__(key if isinstance(key, tuple) else (key,)))
5555

5656
# Item assignment: forwards to the wrapped tensor's __setitem__, passing value._t,
# so `value` has to expose a `_t` attribute (the trailing comment is a disabled fallback for other values).
# A key that is not already a tuple is packed into a one-element tuple first.
# Diff view: the "57-" line is the previous revision (list-based), the "57+" line the current one.
def __setitem__(self, key, value):
57-
self._t.__setitem__(key if isinstance(key, list) else [key,], value._t) # if isinstance(value, dtensor) else value)
57+
self._t.__setitem__(key if isinstance(key, tuple) else (key,), value._t) # if isinstance(value, dtensor) else value)

src/Creator.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,44 @@ struct DeferredFull : public Deferred
128128
// set_value(std::move(TypeDispatch<x::Creator>(_dtype, op, _shape, _val)));
129129
}
130130

131-
// FIXME mlir
131+
template<typename T>
132+
struct ValAndDType
133+
{
134+
static ::mlir::Value op(::mlir::OpBuilder & builder, ::mlir::Location loc, const PyScalar & val, ::imex::ptensor::DType & dtyp)
135+
{
136+
dtyp = jit::PT_DTYPE<T>::value;
137+
138+
if constexpr (std::is_floating_point_v<T>) return ::imex::createFloat<sizeof(T)*8>(loc, builder, val._float);
139+
else if constexpr (std::is_same_v<bool, T>) return ::imex::createInt<1>(loc, builder, val._int);
140+
else if constexpr (std::is_integral_v<T>) return ::imex::createInt<sizeof(T)*8>(loc, builder, val._int);
141+
assert("Unsupported dtype in dispatch");
142+
return {};
143+
};
144+
};
145+
146+
// Emits the ptensor CreateOp for this deferred "full" tensor and registers the result
// with the dependency manager under this object's guid.
// builder/loc: MLIR insertion point and location
// dm:          receives the created value plus a callback which wraps the returned
//              buffers into a tensor (mk_tnsr) and stores it via set_value()
// Always returns false.
bool generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::DepManager & dm) override
{
    // one index constant per dimension of the requested shape;
    // the index is unsigned to match _shape.size()
    const std::size_t nd = _shape.size();
    ::mlir::SmallVector<::mlir::Value> shp(nd);
    for(std::size_t i = 0; i < nd; ++i) {
        shp[i] = ::imex::createIndex(loc, builder, _shape[i]);
    }

    // dispatch on the element type to get the fill-value constant and the ptensor dtype
    ::imex::ptensor::DType dtyp;
    ::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);

    // 1-bit zero constant passed as an extra CreateOp operand (a placeholder, judging by its name)
    auto dmy = ::imex::createInt<1>(loc, builder, 0);
    // the transceiver's address is handed to the op as an index value
    auto team = ::imex::createIndex(loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));

    // NOTE(review): the callback captures `this`; presumably the dependency manager
    // invokes it while this object is still alive - confirm.
    dm.addVal(this->guid(),
              builder.create<::imex::ptensor::CreateOp>(loc, shp, dtyp, val, dmy, team),
              [this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
                     uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
                  assert(rank == this->_shape.size());
                  this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides,
                                                    gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
              });
    return false;
}
132169

133170
FactoryId factory() const
134171
{

src/DDPTensorImpl.cpp

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
3+
// Concrete implementation of tensor_i.
4+
// Interfaces are based on shared_ptr<tensor_i>.
5+
6+
#include <ddptensor/DDPTensorImpl.hpp>
7+
#include <ddptensor/CppTypes.hpp>
8+
9+
#include <algorithm>
10+
11+
12+
// Builds a tensor around buffers that already exist.
// The `ndims` entries of `sizes` and `strides` are copied into freshly allocated arrays;
// every other pointer argument is stored exactly as passed in.
DDPTensorImpl::DDPTensorImpl(DTypeId dtype, uint64_t ndims,
                             void * allocated, void * aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
                             uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned,
                             rank_type owner)
    : _owner(owner),
      _allocated(allocated),
      _aligned(aligned),
      _sizes(new intptr_t[ndims]),
      _strides(new intptr_t[ndims]),
      _gs_allocated(gs_allocated),
      _gs_aligned(gs_aligned),
      _lo_allocated(lo_allocated),
      _lo_aligned(lo_aligned),
      _offset(offset),
      _ndims(ndims),
      _dtype(dtype)
{
    // private copies of the shape metadata
    std::copy_n(sizes, ndims, _sizes);
    std::copy_n(strides, ndims, _strides);
}
32+
33+
// Allocates a new, uninitialized tensor of the given shape (currently limited to rank <= 1).
// dtype: element type, determines the element size
// shp:   extent per dimension
// owner: stored as the owning rank
// The global-shape and local-offset pointers are not assigned by this constructor.
DDPTensorImpl::DDPTensorImpl(DTypeId dtype, const shape_type & shp, rank_type owner)
    : _owner(owner),
      _ndims(shp.size()),
      _dtype(dtype)
{
    const auto nd = shp.size();
    assert(nd <= 1);

    // The shape has to be recorded before the data buffer is sized: size() reads _sizes.
    // (Calling alloc() first, as before, evaluated size() on a not-yet-allocated _sizes.)
    _sizes = new intptr_t[nd];
    _strides = new intptr_t[nd];

    // row-major strides: innermost dimension has stride 1, each outer one the product of
    // all inner extents (identical to the previous formula for rank <= 1)
    intptr_t stride = 1;
    for(auto d = nd; d > 0; --d) {
        _sizes[d-1] = shp[d-1];
        _strides[d-1] = stride;
        stride *= shp[d-1];
    }

    // data buffer, aligned to the element size
    const auto esz = sizeof_dtype(_dtype);
    _allocated = new (std::align_val_t(esz)) char[esz*size()];
    _aligned = _allocated;
    _offset = 0;
}
49+
50+
// Creates a new tensor with the same dtype, sizes, strides and owner as this one,
// backed by a freshly allocated data buffer of size() elements.
// copy: if true, the element data is copied into the new buffer; otherwise it stays uninitialized.
// The local-offsets array is duplicated; the global-shape pointers are shared with this tensor.
DDPTensorImpl::ptr_type DDPTensorImpl::clone(bool copy)
{
    // FIXME memory leak
    auto nd = ndims();
    auto sz = size();
    auto esz = sizeof_dtype(dtype());
    auto bsz = sz * esz;
    auto allocated = new (std::align_val_t(esz)) char[bsz];
    auto aligned = allocated;
    if(copy) {
        // The live data starts _offset elements into _aligned (see data()); copy from there.
        // NOTE(review): assumes the elements are stored contiguously - strided views need more work.
        const char * src = static_cast<const char *>(_aligned) + _offset * static_cast<intptr_t>(esz);
        memcpy(aligned, src, bsz);
    }
    // FIXME jit returns private mem, so the global shape is shared rather than copied
    auto gs_allocated = _gs_allocated;
    auto gs_aligned = _gs_aligned;
    auto lo_allocated = new uint64_t[nd];
    auto lo_aligned = lo_allocated;
    memcpy(lo_aligned, _lo_aligned, nd*sizeof(*lo_aligned));

    // The new buffer holds exactly size() elements starting at its beginning, so the clone's
    // offset is 0; reusing a non-zero _offset would address memory past the end of the buffer.
    // strides and sizes are allocated/copied in constructor
    return std::make_shared<DDPTensorImpl>(dtype(), nd, allocated, aligned, intptr_t{0}, _sizes, _strides,
                                           gs_allocated, gs_aligned, lo_allocated, lo_aligned, owner());
}
74+
75+
void DDPTensorImpl::alloc()
76+
{
77+
auto esz = sizeof_dtype(_dtype);
78+
_allocated = new (std::align_val_t(esz)) char[esz*size()];
79+
_aligned = _allocated;
80+
auto nds = ndims();
81+
_sizes = new intptr_t[nds];
82+
_strides = new intptr_t[nds];
83+
_offset = 0;
84+
}
85+
86+
void * DDPTensorImpl::data()
87+
{
88+
void * ret;
89+
dispatch(_dtype, _aligned, [this, &ret](auto * ptr) { ret = ptr + this->_offset; });
90+
return ret;
91+
}
92+
93+
std::string DDPTensorImpl::__repr__() const
94+
{
95+
const auto nd = ndims();
96+
std::ostringstream oss;
97+
oss << "ddptensor{gs=(";
98+
for(auto i=0; i<nd; ++i) oss << _gs_aligned[i] << ", ";
99+
oss << "), loff=(";
100+
for(auto i=0; i<nd; ++i) oss << _lo_aligned[i] << ", ";
101+
oss << "), lsz=(";
102+
for(auto i=0; i<nd; ++i) oss << _sizes[i] << ", ";
103+
oss << "), str=(";
104+
for(auto i=0; i<nd; ++i) oss << _strides[i] << ", ";
105+
oss << "), p=" << _allocated << ", poff=" << _offset << "}\n";
106+
107+
dispatch(_dtype, _aligned, [this, nd, &oss](auto * ptr) {
108+
auto cptr = ptr + this->_offset;
109+
printit(oss, 0, cptr);
110+
});
111+
return oss.str();
112+
}
113+
114+
bool DDPTensorImpl::__bool__() const
115+
{
116+
if(! is_replicated())
117+
throw(std::runtime_error("Cast to scalar bool: tensor is not replicated"));
118+
119+
bool res;
120+
dispatch(_dtype, _aligned, [this, &res](auto * ptr) { res = static_cast<bool>(ptr[this->_offset]); });
121+
return res;
122+
}
123+
124+
double DDPTensorImpl::__float__() const
125+
{
126+
if(! is_replicated())
127+
throw(std::runtime_error("Cast to scalar float: tensor is not replicated"));
128+
129+
double res;
130+
dispatch(_dtype, _aligned, [this, &res](auto * ptr) { res = static_cast<double>(ptr[this->_offset]); });
131+
return res;
132+
}
133+
134+
int64_t DDPTensorImpl::__int__() const
135+
{
136+
if(! is_replicated())
137+
throw(std::runtime_error("Cast to scalar int: tensor is not replicated"));
138+
139+
float res;
140+
dispatch(_dtype, _aligned, [this, &res](auto * ptr) { res = static_cast<float>(ptr[this->_offset]); });
141+
return res;
142+
}
143+
144+
void DDPTensorImpl::bufferize(const NDSlice & slc, Buffer & buff) const
145+
{
146+
// FIXME slices/strides
147+
#if 0
148+
if(slc.size() <= 0) return;
149+
NDSlice lslice = NDSlice(slice().tile_shape()).slice(slc);
150+
#endif
151+
assert(_strides[0] == 1);
152+
auto pos = buff.size();
153+
auto sz = size()*item_size();
154+
buff.resize(pos + sz);
155+
void * out = buff.data() + pos;
156+
dispatch(_dtype, _aligned, [this, sz, out](auto * ptr) { memcpy(out, ptr + this->_offset, sz); });
157+
}
158+
159+
// Appends this tensor's description to `args` as four entries, in this order:
//   1. global shape, 2. local tensor data, 3. local offsets, 4. the team.
// Entries 1-3 are newly allocated intptr_t arrays laid out as
//   {allocated pointer, aligned pointer, offset, sizes..., strides...}
// (presumably the memref descriptor layout expected by the jit'ed code - confirm).
// `ndims` must equal this->ndims(), or be 0 while the tensor itself is 1-d.
// NOTE(review): the arrays are created with new[] and not released here; presumably the
// consumer of `args` takes ownership - confirm.
// NOTE(review): the arrays are sized with dtensor_sz() while the assertion below checks
// memref_sz(1); confirm dtensor_sz(1) provides at least the 5 slots written.
void DDPTensorImpl::add_to_args(std::vector<void*> & args, int ndims)
160+
{
161+
assert(ndims == this->ndims() || (ndims == 0 && this->ndims() == 1));
162+
// global shape first: a 1-d array of `ndims` entries with stride 1 and offset 0
163+
intptr_t * buff = new intptr_t[dtensor_sz(1)];
164+
buff[0] = reinterpret_cast<intptr_t>(_gs_allocated);
165+
buff[1] = reinterpret_cast<intptr_t>(_gs_aligned);
166+
buff[2] = 0;
167+
buff[3] = ndims;
168+
buff[4] = 1;
169+
args.push_back(buff);
170+
assert(5 == memref_sz(1));
171+
// local tensor: data pointers and element offset, then `ndims` sizes followed by `ndims` strides
172+
buff = new intptr_t[dtensor_sz(ndims)];
173+
buff[0] = reinterpret_cast<intptr_t>(_allocated);
174+
buff[1] = reinterpret_cast<intptr_t>(_aligned);
175+
buff[2] = static_cast<intptr_t>(_offset);
176+
memcpy(buff+3, _sizes, ndims*sizeof(intptr_t));
177+
memcpy(buff+3+ndims, _strides, ndims*sizeof(intptr_t));
178+
args.push_back(buff);
179+
// local offsets: same 1-d layout as the global shape
180+
buff = new intptr_t[dtensor_sz(1)];
181+
buff[0] = reinterpret_cast<intptr_t>(_lo_allocated);
182+
buff[1] = reinterpret_cast<intptr_t>(_lo_aligned);
183+
buff[2] = 0;
184+
buff[3] = ndims;
185+
buff[4] = 1;
186+
args.push_back(buff);
187+
// finally the team: the constant 1 passed as a pointer-sized value
188+
args.push_back(reinterpret_cast<void*>(1));
189+
}

src/include/ddptensor/DDPTensorImpl.hpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <cstring>
1414
#include <type_traits>
1515
#include <memory>
16+
#include <sstream>
1617

1718

1819
class DDPTensorImpl : public tensor_i
@@ -97,8 +98,11 @@ class DDPTensorImpl : public tensor_i
9798

9899
virtual uint64_t size() const
99100
{
100-
assert(ndims() == 1);
101-
return *_sizes;
101+
switch(ndims()) {
102+
case 0 : return 1;
103+
case 1 : return *_sizes;
104+
default: return std::accumulate(_sizes, _sizes+ndims(), 1, std::multiplies<intptr_t>());
105+
}
102106
}
103107

104108
friend struct Service;
@@ -140,6 +144,30 @@ class DDPTensorImpl : public tensor_i
140144
virtual void bufferize(const NDSlice & slc, Buffer & buff) const;
141145

142146
virtual void add_to_args(std::vector<void*> & args, int ndims);
147+
148+
template<typename T>
149+
void printit(std::ostringstream & oss, uint64_t d, T * cptr) const
150+
{
151+
auto stride = _strides[d];
152+
auto sz = _sizes[d];
153+
if(d==ndims()-1) {
154+
oss << "[";
155+
for(auto i=0; i<sz; ++i) {
156+
oss << cptr[i*stride];
157+
if(i<sz-1) oss << " ";
158+
}
159+
oss << "]";
160+
} else {
161+
oss << "[";
162+
for(auto i=0; i<sz; ++i) {
163+
if(i) for(auto x=0; x<=d; ++x) oss << " ";
164+
printit(oss, d+1, cptr);
165+
if(i<sz-1) oss << "\n";
166+
cptr += stride;
167+
}
168+
oss << "]";
169+
}
170+
}
143171
};
144172

145173
template<typename ...Ts>

0 commit comments

Comments
 (0)