Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/cpu/comm/ccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op)

switch (data.scalar_type()) {
case c10::ScalarType::BFloat16: data_size = numel * 2; break;
case c10::ScalarType::Half: data_size = numel * 2; break;
case c10::ScalarType::Float: data_size = numel * 4; break;
default: data_type_fallback = true;
}
Expand Down
86 changes: 75 additions & 11 deletions csrc/cpu/comm/coll_mpi.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,54 @@
// 1 = handoff, 0 = thread-split
#define HANDOFF 0

#include <mpi.h>
#include "coll_mpi.hpp"

#if !HANDOFF // THREAD_SPLIT
#include <omp.h>
std::vector<MPI_Comm> thread_comm;
static bool thread_comm_inited = false;
#endif

void init_mpi(void)
{
MPI_Init(NULL, NULL);
int mpi_inited;
MPI_Initialized(&mpi_inited);
if (!mpi_inited) {
#if HANDOFF
MPI_Init(NULL, NULL);
#else // THREAD_SPLIT
int provided;
MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided);
#endif
}
init_mpi_thread_comms();
}

//int size, rank;
//MPI_Comm_size(MPI_COMM_WORLD, &size);
//MPI_Comm_rank(MPI_COMM_WORLD, &rank);
void init_mpi_thread_comms(void)
{
#if !HANDOFF
if (!thread_comm_inited) {
MPI_Info info;
char s[16];
int num_threads = omp_get_max_threads();
thread_comm.resize(num_threads);
for (int i = 0; i < num_threads; i++) {
MPI_Comm_dup(MPI_COMM_WORLD, &thread_comm[i]);
snprintf(s, 16, "%d", i);
MPI_Info_create(&info);
MPI_Info_set(info, "thread_id", s);
MPI_Comm_set_info(thread_comm[i], info);
MPI_Info_free(&info);
}
thread_comm_inited = true;
}
#endif
}

/*
char temp_buf[64*1024*1024];

/*
void naive_all_reduce(int world_size, int rank, void* buf, size_t data_size, size_t numel, c10::ScalarType scalar_type)
{
if (rank == 0) {
Expand Down Expand Up @@ -155,19 +191,47 @@ void ring_all_reduce(int world_size, int rank, void* buf, size_t data_size, size

void mpi_all_reduce(int world_size, int rank, void* buf, size_t data_size, size_t numel, c10::ScalarType scalar_type)
{
#if HANDOFF
switch (scalar_type) {
case c10::ScalarType::BFloat16:
//naive_all_reduce(world_size, rank, buf, data_size, numel, scalar_type);
//ring_all_reduce(world_size, rank, buf, data_size, numel, scalar_type);
//rabenseifner_all_reduce(world_size, rank, buf, data_size, numel, scalar_type);
MPI_Allreduce(MPI_IN_PLACE, buf, numel, MPIX_C_BF16, MPI_SUM, MPI_COMM_WORLD);
break;
case c10::ScalarType::Half:
MPI_Allreduce(MPI_IN_PLACE, buf, numel, MPIX_C_FLOAT16, MPI_SUM, MPI_COMM_WORLD);
break;
case c10::ScalarType::Float:
MPI_Allreduce(MPI_IN_PLACE, buf, numel, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
//naive_all_reduce(world_size, rank, buf, data_size, numel, scalar_type);
//ring_all_reduce(world_size, rank, buf, data_size, numel, scalar_type);
//rabenseifner_all_reduce(world_size, rank, buf, data_size, numel, scalar_type);
break;
default: assert(!"Should not get here");
}

#else // THREAD_SPLIT

// Could tune number of threads for performance based on numel...
int nthds = std::min((size_t)omp_get_max_threads(), numel);

#pragma omp parallel for num_threads(nthds) schedule(static) shared(nthds, numel)
for (int tid = 0; tid < nthds; tid++)
{
size_t my_numel = numel / nthds;
char *my_buf = (char *)buf + (tid * my_numel);
if (tid == nthds - 1) { // Last thread may have uneven number of elements
my_numel = numel - (my_numel * (nthds - 1)); // Could balance better...
}

switch (scalar_type) {
case c10::ScalarType::BFloat16:
MPI_Allreduce(MPI_IN_PLACE, my_buf, my_numel, MPIX_C_BF16, MPI_SUM, thread_comm[tid]);
break;
case c10::ScalarType::Half:
MPI_Allreduce(MPI_IN_PLACE, my_buf, my_numel, MPIX_C_FLOAT16, MPI_SUM, thread_comm[tid]);
break;
case c10::ScalarType::Float:
MPI_Allreduce(MPI_IN_PLACE, my_buf, my_numel, MPI_FLOAT, MPI_SUM, thread_comm[tid]);
break;
default: assert(!"Should not get here");
}
} // omp parallel for
#endif // THREAD_SPLIT
}

1 change: 1 addition & 0 deletions csrc/cpu/comm/coll_mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <torch/extension.h>

void init_mpi(void);
void init_mpi_thread_comms(void);
void mpi_all_reduce(int world_size, int rank, void* buf, size_t data_size, size_t numel, c10::ScalarType scalar_type);

#endif //_COLL_MPI__HPP_