feat: lower mgather/mscatter to pto-isa MGATHER/MSCATTER#468
feat: lower mgather/mscatter to pto-isa MGATHER/MSCATTER#468FangRui0 wants to merge 1 commit intohw-native-sys:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the lowering of pto.mgather and pto.mscatter to EmitC by incorporating index support and wrapping global memrefs using buildGlobalTensorFromMemref. The review feedback identifies duplicated logic for memref wrapping in both the gather and scatter conversion patterns and suggests extracting this into a common helper function to improve code maintainability.
| Value memArg = mem; | ||
| if (auto memMrTy = dyn_cast<MemRefType>(op.getMem().getType())) { | ||
| bool isGlobal = true; | ||
| if (auto asAttr = | ||
| dyn_cast_or_null<pto::AddressSpaceAttr>(memMrTy.getMemorySpace())) { | ||
| auto as = asAttr.getAddressSpace(); | ||
| isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); | ||
| } | ||
| if (isGlobal) { | ||
| if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), mem, | ||
| memMrTy, op.getOperation())) | ||
| memArg = gt; | ||
| } | ||
| } |
There was a problem hiding this comment.
This logic for wrapping a global memref with GlobalTensor is duplicated in PTOMScatterToMSCATTER at lines 4798-4811. To improve maintainability and reduce code duplication, consider extracting this into a helper function.
For example, you could define a helper function like this:
static Value maybeWrapGMemWithGlobalTensor(ConversionPatternRewriter &rewriter, Operation *op, Value peeledMem, Value originalMem) {
Value memArg = peeledMem;
if (auto memMrTy = dyn_cast<MemRefType>(originalMem.getType())) {
bool isGlobal = true;
if (auto asAttr = dyn_cast_or_null<pto::AddressSpaceAttr>(memMrTy.getMemorySpace())) {
auto as = asAttr.getAddressSpace();
isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero);
}
if (isGlobal) {
if (Value gt = buildGlobalTensorFromMemref(rewriter, op->getLoc(), peeledMem, memMrTy, op)) {
memArg = gt;
}
}
}
return memArg;
}Then you could replace this block with a call like:
Value memArg = maybeWrapGMemWithGlobalTensor(rewriter, op.getOperation(), mem, op.getMem());
And do the same in PTOMScatterToMSCATTER.
Replace TLOAD/TSTORE fallback with MGATHER(dst, gm, idx) and MSCATTER(gm, src, idx); wrap GM memrefs as GlobalTensor like TLOAD/TSTORE.
Replace TLOAD/TSTORE fallback with MGATHER(dst, gm, idx) and MSCATTER(gm, src, idx); wrap GM memrefs as GlobalTensor like TLOAD/TSTORE.