44#include < ddptensor/DDPTensorImpl.hpp>
55#include < ddptensor/MPITransceiver.hpp>
66
7+ #include < imex/Dialect/PTensor/IR/PTensorOps.h>
8+
79#include < cassert>
810#include < memory>
911
1012using container_type = std::unordered_map<id_type, std::unique_ptr<DDPTensorImpl>>;
1113
1214static container_type gtensors;
15+ static id_type _nguid = -1 ;
16+ inline id_type get_guid ()
17+ {
18+ return ++_nguid;
19+ }
1320
1421// Transceiver * theTransceiver = MPITransceiver();
1522
23+ template <typename T>
24+ T * mr_to_ptr (void * ptr, intptr_t offset)
25+ {
26+ auto mr = reinterpret_cast <intptr_t *>(ptr);
27+ return reinterpret_cast <T*>(ptr) + offset; // &mr.aligned[mr.offset]
28+ }
29+
1630extern " C" {
1731
1832// Register a global tensor of given shape.
19- // Accepts a guid which might have been reserved before. Returns guid (reserved or new) .
33+ // Returns guid.
2034// The runtime does not own or manage any memory.
21- id_t idtr_init_dtensor (const uint64_t * shape, uint64_t N, id_t guid )
35+ id_t idtr_init_dtensor (const uint64_t * shape, uint64_t nD )
2236{
23- assert ( guid != UNKNOWN_GUID );
24- gtensors[guid] = std::unique_ptr<DDPTensorImpl>(new DDPTensorImpl (shape, N) );
37+ auto guid = get_guid ( );
38+ gtensors[guid] = std::unique_ptr<DDPTensorImpl>(nD ? new DDPTensorImpl (shape, nD) : new DDPTensorImpl );
2539 return guid;
2640}
2741
42+ id_t _idtr_init_dtensor (void * alloced, void * aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD)
43+ {
44+ return idtr_init_dtensor (mr_to_ptr<uint64_t >(aligned, offset), nD);
45+ }
46+
2847// Get the offsets (one for each dimension) of the local partition of a distributed tensor in number of elements.
2948// Result is stored in provided array.
30- void idtr_local_offsets (id_t guid, uint64_t * offsets, uint64_t N )
49+ void idtr_local_offsets (id_t guid, uint64_t * offsets, uint64_t nD )
3150{
3251 const auto & tnsr = gtensors.at (guid);
3352 auto slcs = tnsr->slice ().local_slice ().slices ();
53+ assert (nD == slcs.size ());
3454 int i = -1 ;
3555 for (auto s : slcs) {
3656 offsets[++i] = s._start ;
3757 }
3858}
3959
60+ void _idtr_local_offsets (id_t guid, void * alloced, void * aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD)
61+ {
62+ idtr_local_offsets (guid, mr_to_ptr<uint64_t >(aligned, offset), nD);
63+ }
64+
4065// Get the shape (one size for each dimension) of the local partition of a distributed tensor in number of elements.
4166// Result is stored in provided array.
4267void idtr_local_shape (id_t guid, uint64_t * lshape, uint64_t N)
@@ -46,10 +71,45 @@ void idtr_local_shape(id_t guid, uint64_t * lshape, uint64_t N)
4671 std::copy (shp.begin (), shp.end (), lshape);
4772}
4873
74+ void _idtr_local_shape (id_t guid, void * alloced, void * aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD)
75+ {
76+ idtr_local_shape (guid, mr_to_ptr<uint64_t >(aligned, offset), nD);
77+ }
78+
79+ // convert id of our reduction op to id of imex::ptensor reduction op
80+ static ReduceOpId mlir2ddpt (const ::imex::ptensor::ReduceOpId rop)
81+ {
82+ switch (rop) {
83+ case ::imex::ptensor::MEAN:
84+ return MEAN;
85+ case ::imex::ptensor::PROD:
86+ return PROD;
87+ case ::imex::ptensor::SUM:
88+ return SUM;
89+ case ::imex::ptensor::STD:
90+ return STD;
91+ case ::imex::ptensor::VAR:
92+ return VAR;
93+ case ::imex::ptensor::MAX:
94+ return MAX;
95+ case MIN:
96+ return MIN;
97+ default :
98+ throw std::runtime_error (" Unknown reduction operation" );
99+ }
100+ }
101+
49102// Elementwise inplace allreduce
50- void idtr_reduce_all (void * inout, DTypeId dtype, size_t N, RedOpType op)
103+ void idtr_reduce_all (void * inout, DTypeId dtype, uint64_t N, int op)
104+ {
105+
106+ getTransceiver ()->reduce_all (inout, dtype, N, mlir2ddpt (static_cast <imex::ptensor::ReduceOpId>(op)));
107+ }
108+
109+ // FIXME hard-coded 0d tensor
110+ void _idtr_reduce_all (uint64_t * allocated, uint64_t * aligned, uint64_t offset, DTypeId dtype, int op)
51111{
52- getTransceiver ()-> reduce_all (inout , dtype, N , op);
112+ idtr_reduce_all (aligned + offset , dtype, 1 , op);
53113}
54114
55115} // extern "C"
0 commit comments