Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ Documentation for TransferBench is available at

## v1.67.00
### Added
- Added NIC_TRAFFIC_CLASS to set the DSCP/traffic class byte in the RoCE GRH for QPs (RoCE only)
- Added NIC_SERVICE_LEVEL to set the IB service level (sl) for QPs (IB and RoCE)
- Added NIC_TRAFFIC_CLASS to set the DSCP/traffic class byte in the RoCE GRH for QPs (equivalent to NCCL_IB_TC)
- Added NIC_TRAFFIC_CLASS_FIFO to set a DSCP/traffic class to steer the control traffic into another priority queue (equivalent to NCCL_IB_FIFO_TC)
- Added NIC_SERVICE_LEVEL to set the IB service level (sl) for QPs (equivalent to NCCL_IB_SL)
- Initial support for pod communication. Requires compatible hardware / ROCm version and subject to further testing
- This potentially enables GFX/DMA executors to access SRC/DST memory locations on GPUs within the same pod
- Pod membership requires amd-smi; however, this requirement can be skipped by setting TB_FORCE_SINGLE_POD=1
Expand Down
20 changes: 17 additions & 3 deletions src/client/EnvVars.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class EnvVars
int nicChunkBytes; // Number of bytes to send per chunk for RDMA operations
int nicCqPollBatch; // Number of CQ entries to poll per ibv_poll_cq call
int nicRelaxedOrder; // Use relaxed ordering for RDMA
int nicFifoTrafficClass; // DSCP/traffic class byte for control (FIFO) QPs
int nicServiceLevel; // IB service level (sl) for InfiniBand QPs
int nicTrafficClass; // DSCP/traffic class byte for RoCE GRH
int roceVersion; // RoCE version number
Expand Down Expand Up @@ -186,11 +187,20 @@ class EnvVars
ipAddressFamily = GetEnvVar("IP_ADDRESS_FAMILY" , 4);
nicChunkBytes = GetEnvVar("NIC_CHUNK_BYTES" , 1073741824);
nicCqPollBatch = GetEnvVar("NIC_CQ_POLL_BATCH" , 4);
nicRelaxedOrder = GetEnvVar("NIC_RELAX_ORDER" , 1);
nicServiceLevel = GetEnvVar("NIC_SERVICE_LEVEL" , 0);
nicTrafficClass = GetEnvVar("NIC_TRAFFIC_CLASS" , 0);
nicFifoTrafficClass = GetEnvVar("NIC_TRAFFIC_CLASS_FIFO", 0);
nicRelaxedOrder = GetEnvVar("NIC_RELAX_ORDER" , 1);
nicServiceLevel = GetEnvVar("NIC_SERVICE_LEVEL" , 0);
nicTrafficClass = GetEnvVar("NIC_TRAFFIC_CLASS" , 0);

// Check that NIC service level and traffic class are in valid ranges
if (nicFifoTrafficClass < 0 || nicFifoTrafficClass > 255) {
printf("[ERROR] NIC_TRAFFIC_CLASS_FIFO must be in range 0..255 (got %d)\n", nicFifoTrafficClass);
exit(1);
}
if (nicFifoTrafficClass != 0 && numIterations <= 0) {
printf("[ERROR] NIC_TRAFFIC_CLASS_FIFO requires NUM_ITERATIONS > 0 (timed/infinite mode is not supported)\n");
exit(1);
}
if (nicServiceLevel < 0 || nicServiceLevel > 15) {
printf("[ERROR] NIC_SERVICE_LEVEL must be in range 0..15 (got %d)\n", nicServiceLevel);
exit(1);
Expand Down Expand Up @@ -379,6 +389,7 @@ class EnvVars
#if NIC_EXEC_ENABLED
printf(" NIC_CHUNK_BYTES - Number of bytes to send at a time using NIC (default = 1GB)\n");
printf(" NIC_CQ_POLL_BATCH - Number of CQ entries to poll per ibv_poll_cq call (default = 4)\n");
printf(" NIC_TRAFFIC_CLASS_FIFO - DSCP/traffic class byte for control (FIFO) QP GRH (default=0)\n");
printf(" NIC_RELAX_ORDER - Set to non-zero to use relaxed ordering\n");
printf(" NIC_SERVICE_LEVEL - IB service level (sl) for InfiniBand QPs (default=0)\n");
printf(" NIC_TRAFFIC_CLASS - DSCP/traffic class byte for RoCE GRH (default=0)\n");
Expand Down Expand Up @@ -520,6 +531,8 @@ class EnvVars
"Polling %d CQ entries per ibv_poll_cq call", nicCqPollBatch);
Print("NIC_RELAX_ORDER", nicRelaxedOrder,
"Using %s ordering for NIC RDMA", nicRelaxedOrder ? "relaxed" : "strict");
Print("NIC_TRAFFIC_CLASS_FIFO", nicFifoTrafficClass,
"RoCE FIFO/ctrl traffic class (DSCP) set to %d", nicFifoTrafficClass);
Print("NIC_SERVICE_LEVEL", nicServiceLevel,
"IB service level (sl) set to %d", nicServiceLevel);
Print("NIC_TRAFFIC_CLASS", nicTrafficClass,
Expand Down Expand Up @@ -744,6 +757,7 @@ class EnvVars
cfg.nic.ibGidIndex = ibGidIndex;
cfg.nic.ibPort = ibPort;
cfg.nic.ipAddressFamily = ipAddressFamily;
cfg.nic.fifoTrafficClass = nicFifoTrafficClass;
cfg.nic.useRelaxedOrder = nicRelaxedOrder;
cfg.nic.serviceLevel = nicServiceLevel;
cfg.nic.trafficClass = nicTrafficClass;
Expand Down
185 changes: 171 additions & 14 deletions src/header/TransferBench.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ namespace TransferBench
int queueSize = 100; ///< Completion queue size
int roceVersion = 2; ///< RoCE version (used for auto GID detection)
int useRelaxedOrder = 1; ///< Use relaxed ordering
uint8_t fifoTrafficClass = 0; ///< DSCP/traffic class byte for control (FIFO) QPs
uint8_t serviceLevel = 0; ///< IB service level (sl) for InfiniBand QPs
uint8_t trafficClass = 0; ///< DSCP/traffic class byte for RoCE GRH
int useNuma = 0; ///< Switch to closest numa thread for execution
Expand Down Expand Up @@ -2001,9 +2002,10 @@ namespace {
if (nic.maxRecvWorkReq != cfg.nic.maxRecvWorkReq) ADD_ERROR("cfg.nic.maxRecvWorkReq");
if (nic.maxSendWorkReq != cfg.nic.maxSendWorkReq) ADD_ERROR("cfg.nic.maxSendWorkReq");
// nic.queueSize is permitted to be different across ranks
if (nic.roceVersion != cfg.nic.roceVersion) ADD_ERROR("cfg.nic.roceVersion");
if (nic.serviceLevel != cfg.nic.serviceLevel) ADD_ERROR("cfg.nic.serviceLevel");
if (nic.trafficClass != cfg.nic.trafficClass) ADD_ERROR("cfg.nic.trafficClass");
if (nic.roceVersion != cfg.nic.roceVersion) ADD_ERROR("cfg.nic.roceVersion");
if (nic.fifoTrafficClass != cfg.nic.fifoTrafficClass) ADD_ERROR("cfg.nic.fifoTrafficClass");
if (nic.serviceLevel != cfg.nic.serviceLevel) ADD_ERROR("cfg.nic.serviceLevel");
if (nic.trafficClass != cfg.nic.trafficClass) ADD_ERROR("cfg.nic.trafficClass");
if (nic.useRelaxedOrder != cfg.nic.useRelaxedOrder) ADD_ERROR("cfg.nic.useRelaxedOrder");
if (nic.useNuma != cfg.nic.useNuma) ADD_ERROR("cfg.nic.useNuma");
}
Expand Down Expand Up @@ -2749,6 +2751,11 @@ namespace {
ibv_gid dstGid; ///< GID handle for DST NIC
vector<ibv_qp*> srcQueuePairs; ///< Queue pairs for SRC NIC
vector<ibv_qp*> dstQueuePairs; ///< Queue pairs for DST NIC
ibv_cq* srcCtrlCompQueue; ///< Completion queue for SRC ctrl QPs (FIFO TC)
ibv_cq* dstCtrlCompQueue; ///< Completion queue for DST ctrl QPs (FIFO TC)
vector<ibv_qp*> srcCtrlQueuePairs; ///< Control QPs on SRC NIC (FIFO TC)
vector<ibv_qp*> dstCtrlQueuePairs; ///< Control QPs on DST NIC (FIFO TC)
ibv_send_wr ctrlSendWr; ///< Send WR for ctrl signal (zero-byte inline, reused per iteration)
ibv_mr* srcMemRegion; ///< Memory region for SRC
ibv_mr* dstMemRegion; ///< Memory region for DST
int srcDmabufFd; ///< DMA-BUF file descriptor for SRC (if using dmabuf)
Expand Down Expand Up @@ -3230,17 +3237,18 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
static ErrResult CreateQueuePair(ConfigOptions const& cfg,
struct ibv_pd* pd,
struct ibv_cq* cq,
struct ibv_qp*& qp)
struct ibv_qp*& qp,
int maxRecvWr = -1)
{
// Set queue pair attributes
struct ibv_qp_init_attr attr = {};
attr.qp_type = IBV_QPT_RC; // Set type to reliable connection
attr.send_cq = cq; // Send completion queue
attr.recv_cq = cq; // Recv completion queue
attr.cap.max_send_wr = cfg.nic.maxSendWorkReq; // Max send work requests
attr.cap.max_recv_wr = cfg.nic.maxRecvWorkReq; // Max recv work requests
attr.cap.max_send_sge = 1; // Max send scatter-gather entries
attr.cap.max_recv_sge = 1; // Max recv scatter-gather entries
attr.qp_type = IBV_QPT_RC; // Set type to reliable connection
attr.send_cq = cq; // Send completion queue
attr.recv_cq = cq; // Recv completion queue
attr.cap.max_send_wr = cfg.nic.maxSendWorkReq; // Max send work requests
attr.cap.max_recv_wr = (maxRecvWr >= 0) ? maxRecvWr : cfg.nic.maxRecvWorkReq; // Max recv work requests
attr.cap.max_send_sge = 1; // Max send scatter-gather entries
attr.cap.max_recv_sge = 1; // Max recv scatter-gather entries

qp = ibv_create_qp(pd, &attr);
if (qp == NULL)
Expand Down Expand Up @@ -3554,6 +3562,37 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
rss.sendWorkRequests.resize(rss.qpCount);
}

// Create control (FIFO) QPs when fifoTrafficClass is non-zero and this is an RDMA_WRITE
// transfer (srcIsExeNic == true). RDMA_READ transfers never use ctrl QPs so skip them.
// Compute ctrlNumPrepost in 64-bit to avoid signed overflow for large iteration counts,
// then clamp to INT_MAX before casting (NIC_TRAFFIC_CLASS_FIFO already rejects numIterations<=0).
int64_t ctrlNumPrepost64 = (int64_t)(cfg.general.numWarmups + cfg.general.numIterations) *
std::max(1, cfg.general.numSubIterations);
int ctrlNumPrepost = (int)std::min(ctrlNumPrepost64, (int64_t)INT_MAX);
if (cfg.nic.fifoTrafficClass != 0 && rss.srcIsExeNic) {
if (GetRank() == srcMemRank) {
IBV_PTR_CALL(rss.srcCtrlCompQueue, ibv_create_cq,
rss.srcContext, rss.qpCount, NULL, NULL, 0);
rss.srcCtrlQueuePairs.resize(rss.qpCount);
Comment thread
paklui marked this conversation as resolved.
for (int i = 0; i < rss.qpCount; i++) {
// SRC ctrl QP only posts sends, so use default max_recv_wr
ERR_CHECK(CreateQueuePair(cfg, rss.srcProtect, rss.srcCtrlCompQueue, rss.srcCtrlQueuePairs[i]));
ERR_CHECK(InitQueuePair(rss.srcCtrlQueuePairs[i], port, rdmaAccessFlags));
}
}
if (GetRank() == dstMemRank) {
IBV_PTR_CALL(rss.dstCtrlCompQueue, ibv_create_cq,
rss.dstContext, ctrlNumPrepost * rss.qpCount, NULL, NULL, 0);
rss.dstCtrlQueuePairs.resize(rss.qpCount);
Comment thread
paklui marked this conversation as resolved.
for (int i = 0; i < rss.qpCount; i++) {
// DST ctrl QP pre-posts ctrlNumPrepost recv WRs; ensure max_recv_wr is large enough
ERR_CHECK(CreateQueuePair(cfg, rss.dstProtect, rss.dstCtrlCompQueue, rss.dstCtrlQueuePairs[i],
ctrlNumPrepost));
ERR_CHECK(InitQueuePair(rss.dstCtrlQueuePairs[i], port, rdmaAccessFlags));
}
}
}

// Broadcast the SRC/DST port link_layer to all ranks so that the two values can be compared
System::Get().Broadcast(srcMemRank, sizeof(rss.srcPortAttr.link_layer), &rss.srcPortAttr.link_layer);
System::Get().Broadcast(dstMemRank, sizeof(rss.dstPortAttr.link_layer), &rss.dstPortAttr.link_layer);
Expand All @@ -3562,6 +3601,11 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
rss.srcNicIndex, srcMemRank, rss.dstNicIndex, dstMemRank, rss.srcPortAttr.link_layer, rss.dstPortAttr.link_layer};
}

// Shared result type for broadcasting QP transition success/failure across MPI ranks
struct QpTransitionResult { ErrType errType; bool rtrFailed; };
static_assert(std::is_trivially_copyable<QpTransitionResult>::value,
"QpTransitionResult must be trivially copyable for MPI broadcast");

ConnInfo dstConnInfo = {};
ConnInfo srcConnInfo = {};
for (int i = 0; i < rss.qpCount; i++) {
Expand Down Expand Up @@ -3591,8 +3635,6 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
// Then move them to ready-to-send (RTS)
// Broadcast each rank's result so all ranks fail together rather than
// hanging on the next iteration's Broadcast when qpCount > 1.
struct QpTransitionResult { ErrType errType; bool rtrFailed; };
static_assert(std::is_trivially_copyable<QpTransitionResult>::value, "QpTransitionResult must be trivially copyable for MPI broadcast");
QpTransitionResult srcQpResult = {ERR_NONE, false};
if (GetRank() == srcMemRank) {
ErrResult err = TransitionQpToRtr(rss.srcQueuePairs[i], dstConnInfo, port, srcIsRoCE, rss.srcPortAttr.active_mtu, cfg.nic.trafficClass, cfg.nic.serviceLevel);
Expand Down Expand Up @@ -3675,6 +3717,86 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
}
}
}

// Exchange ctrl QP numbers and connect them with fifoTrafficClass.
// Guard on srcIsExeNic to match the ctrl QP creation block above.
if (cfg.nic.fifoTrafficClass != 0 && rss.srcIsExeNic) {
ConnInfo srcCtrlInfo = {}, dstCtrlInfo = {};
for (int i = 0; i < rss.qpCount; i++) {
if (GetRank() == srcMemRank) {
srcCtrlInfo.lid = rss.srcPortAttr.lid;
srcCtrlInfo.gid = rss.srcGid;
srcCtrlInfo.gidIdx = srcGidIndex;
srcCtrlInfo.qpn = rss.srcCtrlQueuePairs[i]->qp_num;
srcCtrlInfo.rkey = 0;
srcCtrlInfo.vaddr = 0;
}
System::Get().Broadcast(srcMemRank, sizeof(srcCtrlInfo), &srcCtrlInfo);

if (GetRank() == dstMemRank) {
dstCtrlInfo.lid = rss.dstPortAttr.lid;
dstCtrlInfo.gid = rss.dstGid;
dstCtrlInfo.gidIdx = dstGidIndex;
dstCtrlInfo.qpn = rss.dstCtrlQueuePairs[i]->qp_num;
dstCtrlInfo.rkey = 0;
dstCtrlInfo.vaddr = 0;
}
System::Get().Broadcast(dstMemRank, sizeof(dstCtrlInfo), &dstCtrlInfo);

QpTransitionResult srcCtrlResult = {ERR_NONE, false};
if (GetRank() == srcMemRank) {
ErrResult err = TransitionQpToRtr(rss.srcCtrlQueuePairs[i], dstCtrlInfo, port, srcIsRoCE,
rss.srcPortAttr.active_mtu, cfg.nic.fifoTrafficClass, cfg.nic.serviceLevel);
srcCtrlResult.rtrFailed = (err.errType != ERR_NONE);
if (err.errType == ERR_NONE) err = TransitionQpToRts(rss.srcCtrlQueuePairs[i]);
srcCtrlResult.errType = err.errType;
}
System::Get().Broadcast(srcMemRank, sizeof(srcCtrlResult), &srcCtrlResult);
if (srcCtrlResult.errType != ERR_NONE)
return {ERR_FATAL, "SRC rank %d failed to transition ctrl QP %d to %s",
srcMemRank, i, srcCtrlResult.rtrFailed ? "RTR" : "RTS"};

QpTransitionResult dstCtrlResult = {ERR_NONE, false};
if (GetRank() == dstMemRank) {
ErrResult err = TransitionQpToRtr(rss.dstCtrlQueuePairs[i], srcCtrlInfo, port, dstIsRoCE,
rss.dstPortAttr.active_mtu, cfg.nic.fifoTrafficClass, cfg.nic.serviceLevel);
dstCtrlResult.rtrFailed = (err.errType != ERR_NONE);
if (err.errType == ERR_NONE) err = TransitionQpToRts(rss.dstCtrlQueuePairs[i]);
dstCtrlResult.errType = err.errType;
}
System::Get().Broadcast(dstMemRank, sizeof(dstCtrlResult), &dstCtrlResult);
if (dstCtrlResult.errType != ERR_NONE)
return {ERR_FATAL, "DST rank %d failed to transition ctrl QP %d to %s",
dstMemRank, i, dstCtrlResult.rtrFailed ? "RTR" : "RTS"};
}

// Pre-build reusable ctrl send WR (zero-byte inline IBV_WR_SEND, posted once per iteration)
rss.ctrlSendWr = {};
rss.ctrlSendWr.sg_list = nullptr;
rss.ctrlSendWr.num_sge = 0;
rss.ctrlSendWr.opcode = IBV_WR_SEND;
rss.ctrlSendWr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE;

// DST rank pre-posts recv WRs for all expected iterations so no per-iteration
// coordination (barrier) is needed between executor and non-executor ranks.
// ctrlNumPrepost was computed above and matches the max_recv_wr used when creating ctrl QPs.
if (GetRank() == dstMemRank) {
ibv_recv_wr ctrlRecvWr = {};
ctrlRecvWr.sg_list = nullptr;
ctrlRecvWr.num_sge = 0;
ibv_recv_wr* badRecvWr;
for (int i = 0; i < rss.qpCount; i++) {
for (int n = 0; n < ctrlNumPrepost; n++) {
ibv_recv_wr wr = ctrlRecvWr;
int err = ibv_post_recv(rss.dstCtrlQueuePairs[i], &wr, &badRecvWr);
if (err)
return {ERR_FATAL, "Transfer %d: ibv_post_recv pre-post on ctrl QP %d failed (%s)",
rss.transferIdx, i, strerror(err)};
}
}
}
Comment thread
paklui marked this conversation as resolved.
}

return ERR_NONE;
}

Expand Down Expand Up @@ -3711,6 +3833,18 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
rss.dstQueuePairs.clear();
}

// Destroy ctrl queue pairs and completion queues (only exist when fifoTrafficClass != 0)
if (isSrcRank && !rss.srcCtrlQueuePairs.empty()) {
for (auto qp : rss.srcCtrlQueuePairs) IBV_CALL(ibv_destroy_qp, qp);
rss.srcCtrlQueuePairs.clear();
IBV_CALL(ibv_destroy_cq, rss.srcCtrlCompQueue);
}
if (isDstRank && !rss.dstCtrlQueuePairs.empty()) {
for (auto qp : rss.dstCtrlQueuePairs) IBV_CALL(ibv_destroy_qp, qp);
rss.dstCtrlQueuePairs.clear();
IBV_CALL(ibv_destroy_cq, rss.dstCtrlCompQueue);
}

// Destroy completion queues
if (isSrcRank) IBV_CALL(ibv_destroy_cq, rss.srcCompQueue);
if (isDstRank) IBV_CALL(ibv_destroy_cq, rss.dstCompQueue);
Expand Down Expand Up @@ -4653,7 +4787,30 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
int const exeIndex,
TransferResources& rss)
{
// Loop over each of the queue pairs and post work request
// Ctrl path: executor (src) fires one zero-byte inline IBV_WR_SEND per ctrl QP.
// Recv WRs are pre-posted on the dst side during setup, so no barrier is needed.
if (cfg.nic.fifoTrafficClass != 0 && rss.srcIsExeNic) {
ibv_send_wr* badSendWr;
for (int qpIdx = 0; qpIdx < rss.qpCount; qpIdx++) {
rss.ctrlSendWr.wr_id = qpIdx;
int err = ibv_post_send(rss.srcCtrlQueuePairs[qpIdx], &rss.ctrlSendWr, &badSendWr);
if (err)
Comment thread
paklui marked this conversation as resolved.
return {ERR_FATAL, "Transfer %d: ibv_post_send on ctrl QP %d failed (%s)",
Comment thread
paklui marked this conversation as resolved.
rss.transferIdx, qpIdx, strerror(err)};
}
for (int qpIdx = 0; qpIdx < rss.qpCount; qpIdx++) {
ibv_wc wc;
int nc;
while ((nc = ibv_poll_cq(rss.srcCtrlCompQueue, 1, &wc)) == 0) {}
if (nc < 0)
return {ERR_FATAL, "Transfer %d: ctrl CQ poll error", rss.transferIdx};
if (wc.status != IBV_WC_SUCCESS)
return {ERR_FATAL, "Transfer %d: ctrl send CQ error on QP %llu [status %d]",
rss.transferIdx, wc.wr_id, wc.status};
}
Comment thread
paklui marked this conversation as resolved.
}

// Data path — unchanged: post all RDMA send WRs (stamped with trafficClass)
ibv_send_wr* badWorkReq;
for (int qpIndex = 0; qpIndex < rss.qpCount; qpIndex++) {
size_t numChunks = rss.sendWorkRequests[qpIndex].size();
Expand Down
Loading