Skip to content

Commit 1316395

Browse files
committed
device: synchronize streams to_device for DistArrays
1 parent 7a206d7 commit 1316395

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/TiledArray/device/device_array.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void to_device(TiledArray::DistArray<TiledArray::Tile<UMT>, Policy> &um_array) {
5353
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
5454
tile.tensor(), stream);
5555
}
56+
device::sync_madness_task_with(stream);
5657
};
5758

5859
auto &world = um_array.world();
@@ -87,6 +88,11 @@ void to_host(TiledArray::DistArray<TiledArray::Tile<UMT>, Policy> &um_array) {
8788
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
8889
tile.tensor(), stream);
8990
}
91+
92+
// Synchronize this stream to ensure prefetch completes before task returns
93+
// This prevents race conditions where world.gop.fence() completes before
94+
// all async prefetch operations have finished
95+
device::sync_madness_task_with(stream);
9096
};
9197

9298
auto &world = um_array.world();
@@ -100,6 +106,8 @@ void to_host(TiledArray::DistArray<TiledArray::Tile<UMT>, Policy> &um_array) {
100106
}
101107

102108
world.gop.fence();
109+
// Note: deviceSynchronize() may be redundant after fence() + per-task sync,
110+
// but kept for extra safety to ensure all device operations are complete
103111
DeviceSafeCall(device::deviceSynchronize());
104112
}
105113

0 commit comments

Comments
 (0)