@@ -16,8 +16,6 @@ DDPTensorImpl::DDPTensorImpl(DTypeId dtype, uint64_t ndims,
1616 : _owner(owner),
1717 _allocated(allocated),
1818 _aligned(aligned),
19- _sizes(new intptr_t [ndims]),
20- _strides(new intptr_t [ndims]),
2119 _gs_allocated(gs_allocated),
2220 _gs_aligned(gs_aligned),
2321 _lo_allocated(lo_allocated),
@@ -26,8 +24,15 @@ DDPTensorImpl::DDPTensorImpl(DTypeId dtype, uint64_t ndims,
2624 _ndims(ndims),
2725 _dtype(dtype)
2826{
29- memcpy (_sizes, sizes, ndims*sizeof (*_sizes));
30- memcpy (_strides, strides, ndims*sizeof (*_strides));
27+ if (ndims > 0 ) {
28+ _sizes = new intptr_t [ndims];
29+ _strides = new intptr_t [ndims];
30+ memcpy (_sizes, sizes, ndims*sizeof (*_sizes));
31+ memcpy (_strides, strides, ndims*sizeof (*_strides));
32+ } else {
33+ _owner = REPLICATED;
34+ assert (_aligned);
35+ }
3136}
3237
3338DDPTensorImpl::DDPTensorImpl (DTypeId dtype, const shape_type & shp, rank_type owner)
@@ -72,15 +77,17 @@ DDPTensorImpl::ptr_type DDPTensorImpl::clone(bool copy)
7277 gs_allocated, gs_aligned, lo_allocated, lo_aligned, owner ());
7378}
7479
75- void DDPTensorImpl::alloc ()
80+ void DDPTensorImpl::alloc (bool all )
7681{
7782 auto esz = sizeof_dtype (_dtype);
78- _allocated = new (std::align_val_t (esz)) char [esz*size ()];
83+ _allocated = new (std::align_val_t (esz)) char [esz*local_size ()];
7984 _aligned = _allocated;
80- auto nds = ndims ();
81- _sizes = new intptr_t [nds];
82- _strides = new intptr_t [nds];
8385 _offset = 0 ;
86+ if (all) {
87+ auto nds = ndims ();
88+ _sizes = new intptr_t [nds];
89+ _strides = new intptr_t [nds];
90+ }
8491}
8592
8693void * DDPTensorImpl::data ()
@@ -106,8 +113,11 @@ std::string DDPTensorImpl::__repr__() const
106113
107114 dispatch (_dtype, _aligned, [this , nd, &oss](auto * ptr) {
108115 auto cptr = ptr + this ->_offset ;
109- if (nd>0 ) printit (oss, 0 , cptr);
110- else oss << *cptr;
116+ if (nd>0 ) {
117+ printit (oss, 0 , cptr);
118+ } else {
119+ oss << *cptr;
120+ }
111121 });
112122 return oss.str ();
113123}
@@ -189,3 +199,26 @@ void DDPTensorImpl::add_to_args(std::vector<void*> & args, int ndims)
189199 buff[4 ] = 1 ;
190200 args.push_back (buff);
191201}
202+
203+ void DDPTensorImpl::replicate ()
204+ {
205+ if (is_replicated ()) return ;
206+ auto gsz = size ();
207+ auto lsz = local_size ();
208+ if (gsz > 1 ) throw (std::runtime_error (" Replication implemented for single-element tensors only." ));
209+ if (lsz != gsz) {
210+ assert (lsz == 0 );
211+ auto nd = ndims ();
212+ for (auto i=0 ; i<nd; ++i) {
213+ _sizes[i] = _strides[i] = 1 ;
214+ }
215+ _sizes[nd-1 ] = gsz;
216+ }
217+ dispatch (_dtype, _aligned, [this , lsz, gsz](auto * ptr) {
218+ auto tmp = ptr[this ->_offset ];
219+ if (lsz != gsz) ptr[this ->_offset ] = 0 ;
220+ getTransceiver ()->reduce_all (&ptr[this ->_offset ], this ->_dtype , 1 , SUM);
221+ assert (lsz != gsz || tmp == ptr[this ->_offset ]);
222+ });
223+ set_owner (REPLICATED);
224+ }
0 commit comments