feat: add synchronous communication ops #467
feat: add synchronous communication ops #467 — FangRui0 wants to merge 4 commits into hw-native-sys:main from
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a suite of synchronous point-to-point and collective communication operations to the PTO dialect, including tput, tget, signal operations (tnotify, twait, ttest), and collectives (tbroadcast, gather, scatter, reduce). The changes encompass IR definitions, documentation, C/Python bindings, memory effect specifications, and EmitC lowering patterns. Feedback focuses on correcting memory effects for staging and accumulation tiles in several operations to ensure accurate side-effect analysis. Additionally, there are suggestions to complete the atomic type support in tput lowering, reduce code duplication in the EmitC conversion patterns, and improve documentation consistency for the new collective operations.
| addEffect(effects, &getAccMutable(), MemoryEffects::Read::get()); | ||
| addEffect(effects, &getRecvPingMutable(), MemoryEffects::Read::get()); |
There was a problem hiding this comment.
There appear to be incorrect memory effects for the acc and recvPing operands:
`recvPing` is a receive staging tile, so it is written to. Its effect should be `MemoryEffects::Write`. `acc` is an accumulation tile. It is both read from and written to during the reduction. It should have both `Read` and `Write` effects.
addEffect(effects, &getAccMutable(), MemoryEffects::Read::get());
addEffect(effects, &getAccMutable(), MemoryEffects::Write::get());
addEffect(effects, &getRecvPingMutable(), MemoryEffects::Write::get());
| if constexpr (std::is_same_v<OpTy, pto::TPutOp>) { | ||
| if (op.getAtomicType() == pto::AtomicType::AtomicAdd) | ||
| actualCallee = "pto::comm::TPUT<pto::AtomicType::AtomicAdd>"; | ||
| } |
There was a problem hiding this comment.
The lowering for pto.tput only handles AtomicType::AtomicAdd explicitly. The pto::atomic_type enum also supports AtomicAnd, AtomicOr, AtomicXor, AtomicMin, and AtomicMax. This implementation is incomplete and will not lower other atomic types correctly. Please extend it to support all variants of pto::AtomicType.
| if constexpr (std::is_same_v<OpTy, pto::TPutOp>) { | |
| if (op.getAtomicType() == pto::AtomicType::AtomicAdd) | |
| actualCallee = "pto::comm::TPUT<pto::AtomicType::AtomicAdd>"; | |
| } | |
| if constexpr (std::is_same_v<OpTy, pto::TPutOp>) { | |
| switch (op.getAtomicType()) { | |
| case pto::AtomicType::AtomicNone: | |
| break; // default callee is correct | |
| case pto::AtomicType::AtomicAdd: | |
| actualCallee = "pto::comm::TPUT<pto::AtomicType::AtomicAdd>"; | |
| break; | |
| // TODO: Add cases for other atomic types like And, Or, Xor, Min, Max | |
| default: | |
| return rewriter.notifyMatchFailure(op, "unsupported atomic type for tput lowering"); | |
| } | |
| } |
docs/PTO_IR_manual.md
Outdated
| ##### `pto.comm_tgather` - Collective Gather | ||
|
|
||
| **Summary:** Communication collective that lowers to `pto::comm::TGATHER(...)`. This op is distinct from tile-level `pto.tgather`. | ||
|
|
||
| **Arguments:** `dst`, `ping`, optional `pong`, variadic `group`, `root` | ||
|
|
||
| **Constraints & Verification:** | ||
|
|
||
| - `group` must be non-empty and all members must have identical types. | ||
| - `dst` element type must match the group element type. | ||
| - `ping` / `pong` must be local VEC tile-like values with matching element type. |
docs/PTO_IR_manual.md
Outdated
| ##### `pto.comm_tscatter` - Collective Scatter | ||
|
|
||
| **Summary:** Communication collective that lowers to `pto::comm::TSCATTER(...)`. This op is distinct from tile-level `pto.tscatter`. | ||
|
|
||
| **Arguments:** `src`, `ping`, optional `pong`, variadic `group`, `root` | ||
|
|
||
| **Constraints & Verification:** | ||
|
|
||
| - `group` must be non-empty and all members must have identical types. | ||
| - `src` element type must match the group element type. | ||
| - `ping` / `pong` must be local VEC tile-like values with matching element type. |
| template <typename CollectiveOp> | ||
| struct PTOCommCollectiveToEmitC : public OpConversionPattern<CollectiveOp> { | ||
| using OpConversionPattern<CollectiveOp>::OpConversionPattern; | ||
|
|
||
| explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, | ||
| StringRef apiName) | ||
| : OpConversionPattern<CollectiveOp>(typeConverter, ctx), | ||
| apiName(apiName.str()) {} | ||
|
|
||
| LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto *ctx = rewriter.getContext(); | ||
| Location loc = op.getLoc(); | ||
|
|
||
| SmallVector<Value> operands; | ||
| std::string lambda = "([&]("; | ||
|
|
||
| auto appendParam = [&](StringRef name) { | ||
| if (lambda.back() != '(') | ||
| lambda += ", "; | ||
| lambda += "auto &"; | ||
| lambda += name.str(); | ||
| }; | ||
|
|
||
| auto appendOperand = [&](Value value, StringRef name) { | ||
| appendParam(name); | ||
| operands.push_back(value); | ||
| }; | ||
|
|
||
| auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr<Value> { | ||
| if (!original) | ||
| return failure(); | ||
| return buildCommTileValue(rewriter, loc, original, emitted); | ||
| }; | ||
|
|
||
| if constexpr (std::is_same_v<CollectiveOp, pto::TBroadcastOp>) { | ||
| FailureOr<Value> srcGT = | ||
| buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), | ||
| op.getOperation()); | ||
| FailureOr<Value> pingTile = | ||
| buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); | ||
| auto groupGTs = | ||
| buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); | ||
| if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); | ||
| appendOperand(*srcGT, "__src"); | ||
| appendOperand(*pingTile, "__ping"); | ||
| if (op.getPong()) { | ||
| FailureOr<Value> pongTile = | ||
| buildPong(op.getPong(), adaptor.getPong(), "__pong"); | ||
| if (failed(pongTile)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); | ||
| appendOperand(*pongTile, "__pong"); | ||
| } | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) | ||
| appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); | ||
| lambda += ") { "; | ||
| lambda += "using __GT = std::decay_t<decltype(__g0)>; __GT __group[] = {"; | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) { | ||
| if (i) | ||
| lambda += ", "; | ||
| lambda += "__g" + std::to_string(i); | ||
| } | ||
| lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; | ||
| lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); | ||
| lambda += "); pto::comm::TBROADCAST(__pg, __src, __ping"; | ||
| if (op.getPong()) | ||
| lambda += ", __pong"; | ||
| lambda += "); })"; | ||
| } else if constexpr (std::is_same_v<CollectiveOp, pto::CommTGatherOp>) { | ||
| FailureOr<Value> dstGT = | ||
| buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), | ||
| op.getOperation()); | ||
| FailureOr<Value> pingTile = | ||
| buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); | ||
| auto groupGTs = | ||
| buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); | ||
| if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); | ||
| appendOperand(*dstGT, "__dst"); | ||
| appendOperand(*pingTile, "__ping"); | ||
| if (op.getPong()) { | ||
| FailureOr<Value> pongTile = | ||
| buildPong(op.getPong(), adaptor.getPong(), "__pong"); | ||
| if (failed(pongTile)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); | ||
| appendOperand(*pongTile, "__pong"); | ||
| } | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) | ||
| appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); | ||
| lambda += ") { using __GT = std::decay_t<decltype(__g0)>; __GT __group[] = {"; | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) { | ||
| if (i) | ||
| lambda += ", "; | ||
| lambda += "__g" + std::to_string(i); | ||
| } | ||
| lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; | ||
| lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); | ||
| lambda += "); pto::comm::TGATHER(__pg, __dst, __ping"; | ||
| if (op.getPong()) | ||
| lambda += ", __pong"; | ||
| lambda += "); })"; | ||
| } else if constexpr (std::is_same_v<CollectiveOp, pto::CommTScatterOp>) { | ||
| FailureOr<Value> srcGT = | ||
| buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), | ||
| op.getOperation()); | ||
| FailureOr<Value> pingTile = | ||
| buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); | ||
| auto groupGTs = | ||
| buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); | ||
| if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); | ||
| appendOperand(*srcGT, "__src"); | ||
| appendOperand(*pingTile, "__ping"); | ||
| if (op.getPong()) { | ||
| FailureOr<Value> pongTile = | ||
| buildPong(op.getPong(), adaptor.getPong(), "__pong"); | ||
| if (failed(pongTile)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); | ||
| appendOperand(*pongTile, "__pong"); | ||
| } | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) | ||
| appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); | ||
| lambda += ") { using __GT = std::decay_t<decltype(__g0)>; __GT __group[] = {"; | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) { | ||
| if (i) | ||
| lambda += ", "; | ||
| lambda += "__g" + std::to_string(i); | ||
| } | ||
| lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; | ||
| lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); | ||
| lambda += "); pto::comm::TSCATTER(__pg, __src, __ping"; | ||
| if (op.getPong()) | ||
| lambda += ", __pong"; | ||
| lambda += "); })"; | ||
| } else { | ||
| FailureOr<Value> dstGT = | ||
| buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), | ||
| op.getOperation()); | ||
| FailureOr<Value> accTile = | ||
| buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); | ||
| FailureOr<Value> recvPing = | ||
| buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); | ||
| auto groupGTs = | ||
| buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); | ||
| if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); | ||
| appendOperand(*dstGT, "__dst"); | ||
| appendOperand(*accTile, "__acc"); | ||
| appendOperand(*recvPing, "__recv_ping"); | ||
| if (op.getRecvPong()) { | ||
| FailureOr<Value> recvPong = | ||
| buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); | ||
| if (failed(recvPong)) | ||
| return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); | ||
| appendOperand(*recvPong, "__recv_pong"); | ||
| } | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) | ||
| appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); | ||
| lambda += ") { using __GT = std::decay_t<decltype(__g0)>; __GT __group[] = {"; | ||
| for (size_t i = 0; i < groupGTs->size(); ++i) { | ||
| if (i) | ||
| lambda += ", "; | ||
| lambda += "__g" + std::to_string(i); | ||
| } | ||
| lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; | ||
| lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); | ||
| lambda += "); pto::comm::TREDUCE(__pg, __dst, __acc, __recv_ping"; | ||
| if (op.getRecvPong()) | ||
| lambda += ", __recv_pong"; | ||
| lambda += ", " + reduceOpTok(op.getReduceOp()) + "); })"; | ||
| } | ||
|
|
||
| rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, lambda, ArrayAttr{}, | ||
| ArrayAttr{}, operands); | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
|
|
||
| std::string apiName; | ||
| }; |
There was a problem hiding this comment.
The matchAndRewrite method in PTOCommCollectiveToEmitC has a significant amount of duplicated code across the if constexpr branches. The logic for creating the __group C-style array and the pto::comm::ParallelGroup is nearly identical in all branches.
To improve maintainability, consider extracting this common logic into a helper function. This would reduce code duplication and make the lowering logic for each collective op clearer.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
|
/run a3 |
A3 板测失败
失败用例
|
A3 板测失败详情:PR #467
mrgsort_format2
comm_p2p
comm_collective
|
No description provided.