Skip to content

feat: lower mgather/mscatter to pto-isa MGATHER/MSCATTER#468

Open
FangRui0 wants to merge 1 commit intohw-native-sys:mainfrom
FangRui0:add_op
Open

feat: lower mgather/mscatter to pto-isa MGATHER/MSCATTER#468
FangRui0 wants to merge 1 commit intohw-native-sys:mainfrom
FangRui0:add_op

Conversation

@FangRui0
Copy link
Copy Markdown
Contributor

Replace TLOAD/TSTORE fallback with MGATHER(dst, gm, idx) and MSCATTER(gm, src, idx); wrap GM memrefs as GlobalTensor like TLOAD/TSTORE.

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the lowering of pto.mgather and pto.mscatter to EmitC by incorporating index support and wrapping global memrefs using buildGlobalTensorFromMemref. The review feedback identifies duplicated logic for memref wrapping in both the gather and scatter conversion patterns and suggests extracting this into a common helper function to improve code maintainability.

Comment on lines +2363 to +2376
Value memArg = mem;
if (auto memMrTy = dyn_cast<MemRefType>(op.getMem().getType())) {
bool isGlobal = true;
if (auto asAttr =
dyn_cast_or_null<pto::AddressSpaceAttr>(memMrTy.getMemorySpace())) {
auto as = asAttr.getAddressSpace();
isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero);
}
if (isGlobal) {
if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), mem,
memMrTy, op.getOperation()))
memArg = gt;
}
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for wrapping a global memref with GlobalTensor is duplicated in PTOMScatterToMSCATTER at lines 4798-4811. To improve maintainability and reduce code duplication, consider extracting this into a helper function.

For example, you could define a helper function like this:

static Value maybeWrapGMemWithGlobalTensor(ConversionPatternRewriter &rewriter, Operation *op, Value peeledMem, Value originalMem) {
    Value memArg = peeledMem;
    if (auto memMrTy = dyn_cast<MemRefType>(originalMem.getType())) {
        bool isGlobal = true;
        if (auto asAttr = dyn_cast_or_null<pto::AddressSpaceAttr>(memMrTy.getMemorySpace())) {
            auto as = asAttr.getAddressSpace();
            isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero);
        }
        if (isGlobal) {
            if (Value gt = buildGlobalTensorFromMemref(rewriter, op->getLoc(), peeledMem, memMrTy, op)) {
                memArg = gt;
            }
        }
    }
    return memArg;
}

Then you could replace this block with a call like:
Value memArg = maybeWrapGMemWithGlobalTensor(rewriter, op.getOperation(), mem, op.getMem());

And do the same in PTOMScatterToMSCATTER.

Replace TLOAD/TSTORE fallback with MGATHER(dst, gm, idx) and
MSCATTER(gm, src, idx); wrap GM memrefs as GlobalTensor like TLOAD/TSTORE.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant