-
Notifications
You must be signed in to change notification settings - Fork 59
Add comm engine attrs for collective and async ops #723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1737,6 +1737,7 @@ def BuildAsyncSessionOp : PTO_Op<"comm.build_async_session", [ | |
| let arguments = (ins | ||
| TileBufOrMemRef:$scratch, | ||
| ScalarPtrOrMemRef:$workspace, | ||
| DefaultValuedAttr<PTO_DmaEngineAttr, "::mlir::pto::DmaEngine::SDMA">:$dmaEngine, | ||
| OptionalAttr<I32Attr>:$sync_id, | ||
| OptionalAttr<I64Attr>:$block_bytes, | ||
| OptionalAttr<I64Attr>:$comm_block_offset, | ||
|
|
@@ -1761,7 +1762,8 @@ def TPutAsyncOp : PTO_Op<"comm.tput_async", [ | |
| let arguments = (ins | ||
| PTODpsType:$dst, | ||
| PTODpsType:$src, | ||
| AsyncSessionType:$session | ||
| AsyncSessionType:$session, | ||
| DefaultValuedAttr<PTO_DmaEngineAttr, "::mlir::pto::DmaEngine::SDMA">:$dmaEngine | ||
| ); | ||
|
|
||
| let results = (outs AsyncEventType:$event); | ||
|
|
@@ -1782,7 +1784,8 @@ def TGetAsyncOp : PTO_Op<"comm.tget_async", [ | |
| let arguments = (ins | ||
| PTODpsType:$dst, | ||
| PTODpsType:$src, | ||
| AsyncSessionType:$session | ||
| AsyncSessionType:$session, | ||
| DefaultValuedAttr<PTO_DmaEngineAttr, "::mlir::pto::DmaEngine::SDMA">:$dmaEngine | ||
| ); | ||
|
|
||
| let results = (outs AsyncEventType:$event); | ||
|
|
@@ -1933,6 +1936,7 @@ def TBroadcastOp : PTO_Op<"comm.tbroadcast", [ | |
| PTODpsType:$ping, | ||
| Optional<PTODpsType>:$pong, | ||
| Variadic<PTODpsType>:$group, | ||
| DefaultValuedAttr<PTO_CollEngineAttr, "::mlir::pto::CollEngine::AIV">:$collEngine, | ||
| I32Attr:$root | ||
| ); | ||
| let results = (outs); | ||
|
|
@@ -1950,6 +1954,7 @@ def CommTGatherOp : PTO_Op<"comm.tgather", [ | |
| PTODpsType:$ping, | ||
| Optional<PTODpsType>:$pong, | ||
| Variadic<PTODpsType>:$group, | ||
| DefaultValuedAttr<PTO_CollEngineAttr, "::mlir::pto::CollEngine::AIV">:$collEngine, | ||
| I32Attr:$root | ||
|
Comment on lines
+1957
to
1958
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inserting the new |
||
| ); | ||
| let results = (outs); | ||
|
|
@@ -1967,6 +1972,7 @@ def CommTScatterOp : PTO_Op<"comm.tscatter", [ | |
| PTODpsType:$ping, | ||
| Optional<PTODpsType>:$pong, | ||
| Variadic<PTODpsType>:$group, | ||
| DefaultValuedAttr<PTO_CollEngineAttr, "::mlir::pto::CollEngine::AIV">:$collEngine, | ||
| I32Attr:$root | ||
|
Comment on lines
+1975
to
1976
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inserting the new |
||
| ); | ||
| let results = (outs); | ||
|
|
@@ -1985,6 +1991,7 @@ def TReduceOp : PTO_Op<"comm.treduce", [ | |
| PTODpsType:$recvPing, | ||
| Optional<PTODpsType>:$recvPong, | ||
| Variadic<PTODpsType>:$group, | ||
| DefaultValuedAttr<PTO_CollEngineAttr, "::mlir::pto::CollEngine::AIV">:$collEngine, | ||
| PTO_ReduceOpAttr:$reduceOp, | ||
| I32Attr:$root | ||
|
Comment on lines
+1994
to
1996
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inserting the new |
||
| ); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -198,6 +198,16 @@ static void bindPTOModule(pybind11::module &m) { | |||||||||
| .value("Min", mlir::pto::ReduceOp::Min) | ||||||||||
| .export_values(); | ||||||||||
|
|
||||||||||
| py::enum_<mlir::pto::DmaEngine>(m, "DmaEngine") | ||||||||||
| .value("SDMA", mlir::pto::DmaEngine::SDMA) | ||||||||||
| .value("URMA", mlir::pto::DmaEngine::URMA) | ||||||||||
| .export_values(); | ||||||||||
|
|
||||||||||
| py::enum_<mlir::pto::CollEngine>(m, "CollEngine") | ||||||||||
| .value("AIV", mlir::pto::CollEngine::AIV) | ||||||||||
| .value("CCU", mlir::pto::CollEngine::CCU) | ||||||||||
| .export_values(); | ||||||||||
|
|
||||||||||
| py::enum_<mlir::pto::SyncOpType>(m, "SyncOpType") | ||||||||||
| .value("TLOAD", mlir::pto::SyncOpType::TLOAD) | ||||||||||
| .value("TSTORE_ACC", mlir::pto::SyncOpType::TSTORE_ACC) | ||||||||||
|
|
@@ -363,6 +373,32 @@ static void bindPTOModule(pybind11::module &m) { | |||||||||
| return cls(a); | ||||||||||
| }, | ||||||||||
| py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); | ||||||||||
|
|
||||||||||
| mlir_attribute_subclass(m, "DmaEngineAttr", | ||||||||||
| [](MlirAttribute a) -> bool { | ||||||||||
| return mlirPTOAttrIsADmaEngineAttr(a); | ||||||||||
| }) | ||||||||||
| .def_classmethod( | ||||||||||
| "get", | ||||||||||
| [](py::object cls, mlir::pto::DmaEngine value, MlirContext ctx) -> py::object { | ||||||||||
| MlirAttribute a = mlirPTODmaEngineAttrGet(ctx, static_cast<int32_t>(value)); | ||||||||||
|
Comment on lines
+383
to
+384
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||
| if (mlirAttributeIsNull(a)) return py::none(); | ||||||||||
| return cls(a); | ||||||||||
| }, | ||||||||||
| py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); | ||||||||||
|
|
||||||||||
| mlir_attribute_subclass(m, "CollEngineAttr", | ||||||||||
| [](MlirAttribute a) -> bool { | ||||||||||
| return mlirPTOAttrIsACollEngineAttr(a); | ||||||||||
| }) | ||||||||||
| .def_classmethod( | ||||||||||
| "get", | ||||||||||
| [](py::object cls, mlir::pto::CollEngine value, MlirContext ctx) -> py::object { | ||||||||||
| MlirAttribute a = mlirPTOCollEngineAttrGet(ctx, static_cast<int32_t>(value)); | ||||||||||
|
Comment on lines
+396
to
+397
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||
| if (mlirAttributeIsNull(a)) return py::none(); | ||||||||||
| return cls(a); | ||||||||||
| }, | ||||||||||
| py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); | ||||||||||
| // [保留 HEAD]: AddressSpaceAttr 定义 | ||||||||||
| mlir_attribute_subclass( | ||||||||||
| m, "AddressSpaceAttr", | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inserting the new
collEngineattribute before the existingrootattribute in the TableGenargumentslist breaks backward compatibility for both C++ and Python APIs. Any existing code callingTBroadcastOpwith positional arguments will now have their arguments mismatched (e.g., passing the root rank tocollEngine). Placing new optional or default-valued attributes at the end of theargumentslist preserves the positional argument order and maintains backward compatibility.