|
7 | 7 |
|
8 | 8 | #include <imex/Dialect/PTensor/IR/PTensorOps.h> |
9 | 9 | #include <imex/Dialect/Dist/IR/DistOps.h> |
10 | | -#include <imex/internal/PassUtils.h> |
| 10 | +#include <imex/Utils/PassUtils.h> |
11 | 11 | #include <mlir/IR/Builders.h> |
12 | 12 |
|
13 | 13 | #if 0 |
@@ -202,6 +202,38 @@ struct DeferredSetItem : public Deferred |
202 | 202 | //set_value(std::move(TypeDispatch<x::SetItem>(a, b, _slc, _b))); |
203 | 203 | } |
204 | 204 |
|
| 205 | + bool generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::DepManager & dm) override |
| 206 | + { |
| 207 | + // get params and extract offsets/sizes/strides |
| 208 | + const auto dtype = this->dtype(); |
| 209 | + auto av = dm.getDependent(builder, _a); |
| 210 | + auto bv = dm.getDependent(builder, _b); |
| 211 | + auto & offs = _slc.offsets(); |
| 212 | + auto & sizes = _slc.sizes(); |
| 213 | + auto & strides = _slc.strides(); |
| 214 | + auto nd = offs.size(); |
| 215 | + // convert C++ slices into vectors of MLIR Values |
| 216 | + std::vector<::mlir::Value> offsV(nd); |
| 217 | + std::vector<::mlir::Value> sizesV(nd); |
| 218 | + std::vector<::mlir::Value> stridesV(nd); |
| 219 | + for(auto i = 0; i<nd; ++i) { |
| 220 | + offsV[i] = ::imex::createIndex(loc, builder, offs[i]); |
| 221 | + sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]); |
| 222 | + stridesV[i] = ::imex::createIndex(loc, builder, strides[i]); |
| 223 | + } |
| 224 | + // insertsliceop has no return value, so we just craete the op... |
| 225 | + builder.create<::imex::ptensor::InsertSliceOp>(loc, av, bv, offsV, sizesV, stridesV); |
| 226 | + // ... and use av as to later create the ptensor |
| 227 | + dm.addVal(this->guid(), av, |
| 228 | + [this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides, |
| 229 | + uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) { |
| 230 | + this->set_value(Registry::get(this->_a).get()); |
| 231 | + // this->set_value(std::move(mk_tnsr(dtype, rank, allocated, aligned, offset, sizes, strides, |
| 232 | + // gs_allocated, gs_aligned, lo_allocated, lo_aligned))); |
| 233 | + }); |
| 234 | + return false; |
| 235 | + } |
| 236 | + |
205 | 237 | FactoryId factory() const |
206 | 238 | { |
207 | 239 | return F_SETITEM; |
@@ -264,7 +296,7 @@ struct DeferredGetItem : public Deferred |
264 | 296 | sizesV, |
265 | 297 | stridesV), |
266 | 298 | [this, dtype](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides, |
267 | | - uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) { |
| 299 | + uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) { |
268 | 300 | this->set_value(std::move(mk_tnsr(dtype, rank, allocated, aligned, offset, sizes, strides, |
269 | 301 | gs_allocated, gs_aligned, lo_allocated, lo_aligned))); |
270 | 302 | }); |
|
0 commit comments