42 changes: 36 additions & 6 deletions docs/PTO_IR_manual.md
@@ -5592,14 +5592,15 @@ For each element (i, j):
|------|------|-------------|
| `src0` | `pto.tile_buf` | Source tile buffer |
| `src1` | `pto.tile_buf` | Per-row scalar vector |
| `tmp` | `pto.tile_buf` (optional) | Optional scratch tile forwarded to the `pto-isa` tmp-buffer overload |
| `dst` | `pto.tile_buf` | Destination tile buffer |

**Results:** None. Writes into `dst` via DPS pattern.

**Assembly Format:**

```
-pto.trowexpandadd ins(<src0>, <src1> : <src0_type>, <src1_type>)
+pto.trowexpandadd ins(<src0>, <src1> [, <tmp>] : <src0_type>, <src1_type> [, <tmp_type>])
outs(<dst> : <dst_type>)
```
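
As a reading aid for the row-wise broadcast described above, here is a minimal Python reference model. It is a sketch only: `row_expand_add` is a hypothetical helper, tiles are modelled as nested lists, and it assumes the optional `tmp` operand is pure scratch space that does not affect the result.

```python
from typing import List


def row_expand_add(src0: List[List[float]], src1: List[float]) -> List[List[float]]:
    """Hypothetical reference model: dst[i][j] = src0[i][j] + src1[i]."""
    if len(src0) != len(src1):
        raise ValueError("src1 must supply one scalar per row of src0")
    # Each row of src0 is offset by the matching per-row scalar from src1.
    return [[x + s for x in row] for row, s in zip(src0, src1)]
```

For example, `row_expand_add([[1, 2], [3, 4]], [10, 20])` adds 10 to the first row and 20 to the second.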

@@ -7008,7 +7009,7 @@ pto.tscatter ins(%src, {maskPattern = #pto.mask_pattern<P0101>} : !pto.tile_buf<

##### `pto.mgather` - Gather-Load from Global Memory

-**Summary:** Loads elements from a global table into a VEC tile using per-element indices. Supports an optional A5-only out-of-bounds mode that lowers to the corresponding `MGATHER<...>` template overload.
+**Summary:** Loads elements from a global table into a VEC tile using per-element indices. Supports optional A5-only `coalesce` and `gatherOob` attributes that lower to the corresponding `MGATHER<...>` template overload.

**Semantics:**

@@ -7024,6 +7025,7 @@ elem mode: dst[i, j] = mem[idx[i, j]]
| `mem` | `!pto.partition_tensor_view<...>` / GM memref | `NA` | Global source table |
| `idx` | `pto.tile_buf` | `NA` | Index tile |
| `dst` | `pto.tile_buf` | `NA` | Destination VEC tile |
| `coalesce` | `#pto<coalesce ...>` | inferred | Explicit coalesce mode (`row` / `elem`) |
| `gatherOob` | `#pto<gather_oob ...>` | `undefined` | A5-only out-of-bounds mode (`undefined/clamp/wrap/zero`) |

**Results:** None. Writes into `dst` via DPS pattern.
@@ -7049,6 +7051,10 @@ elem mode: dst[i, j] = mem[idx[i, j]]
- **Out-of-bounds mode**
- Default `gatherOob = undefined` lowers to the default `MGATHER(dst, mem, idx)` overload.
- Non-default `gatherOob` values are only supported on **A5** and lower to `MGATHER<GatherOOB::...>(dst, mem, idx)`.
- **Coalesce mode**
- If `coalesce` is omitted, PTOAS preserves the existing inference from the `idx` tile shape/layout.
- `coalesce = #pto<coalesce row>` lowers to `MGATHER<pto::Coalesce::Row, ...>`.
- `coalesce = #pto<coalesce elem>` lowers to `MGATHER<pto::Coalesce::Elem, ...>`.
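
The elem-mode rule `dst[i, j] = mem[idx[i, j]]` combined with the `gatherOob` values can be pictured with the Python sketch below. The manual names the modes but this diff does not define them, so the behaviours here are assumptions: `clamp` pins the index to the valid range, `wrap` reduces it modulo the table length, `zero` yields `0`, and `undefined` is modelled as an error. `gather_elem` is a hypothetical helper, not part of PTOAS.

```python
def gather_elem(mem, idx, oob="undefined"):
    """Hypothetical model of elem-mode gather over a flat table `mem`."""
    n = len(mem)

    def load(k):
        if 0 <= k < n:
            return mem[k]
        if oob == "clamp":   # assumed: pin to nearest valid index
            return mem[min(max(k, 0), n - 1)]
        if oob == "wrap":    # assumed: index modulo table length
            return mem[k % n]
        if oob == "zero":    # assumed: out-of-bounds reads produce 0
            return 0
        # "undefined" has no specified result; the model refuses to guess.
        raise IndexError(f"index {k} out of bounds for table of length {n}")

    return [[load(k) for k in row] for row in idx]
```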

**Hardware Mapping:**

@@ -7060,6 +7066,10 @@ elem mode: dst[i, j] = mem[idx[i, j]]
pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>)
outs(%dst : !pto.tile_buf<...>)

pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>)
outs(%dst : !pto.tile_buf<...>)
{coalesce = #pto<coalesce elem>}

pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>)
outs(%dst : !pto.tile_buf<...>)
{gatherOob = #pto<gather_oob zero>}
@@ -7069,7 +7079,7 @@ pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>)

##### `pto.mscatter` - Scatter-Store to Global Memory

-**Summary:** Stores elements from a VEC tile into a global table using per-element indices. Supports optional A5-only atomic and out-of-bounds modes that lower to the corresponding `MSCATTER<...>` template overload family.
+**Summary:** Stores elements from a VEC tile into a global table using per-element indices. Supports optional A5-only `coalesce`, atomic, out-of-bounds, and conflict-mode attributes that lower to the corresponding `MSCATTER<...>` template overload family.

**Semantics:**

@@ -7085,8 +7095,10 @@ elem mode: mem[idx[i, j]] = src[i, j]
| `src` | `pto.tile_buf` | `NA` | Source VEC tile |
| `idx` | `pto.tile_buf` | `NA` | Index tile |
| `mem` | `!pto.partition_tensor_view<...>` / GM memref | `NA` | Global destination table |
| `coalesce` | `#pto<coalesce ...>` | inferred | Explicit coalesce mode (`row` / `elem`) |
| `scatterAtomicOp` | `#pto<scatter_atomic_op ...>` | `none` | A5-only atomic mode (`none/add/max/min`) |
| `scatterOob` | `#pto<scatter_oob ...>` | `undefined` | A5-only out-of-bounds mode (`undefined/skip/clamp/wrap`) |
| `scatterConflict` | `#pto<scatter_conflict ...>` | omitted | Optional A5 conflict mode (`default` / `last`) |

**Results:** None. Writes into `mem` via DPS pattern.

@@ -7117,6 +7129,9 @@ elem mode: mem[idx[i, j]] = src[i, j]
- **Out-of-bounds modes**
- Default `scatterOob = undefined` lowers to the 1-template-parameter `MSCATTER<Atomic>(mem, src, idx)` form when only atomic is specified, or to the default overload when both attrs are default.
- Non-default `scatterOob` values are only supported on **A5** and lower to `MSCATTER<ScatterAtomicOp::..., ScatterOOB::...>(mem, src, idx)`.
- **Coalesce and conflict modes**
- If `coalesce` is omitted, PTOAS preserves the existing inference from the `idx` tile shape/layout.
- `scatterConflict` is only meaningful on A5 and lowers by filling the full `MSCATTER<Coalesce, Atomic, Oob, Conflict>` template parameter list.
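
To make the interaction of the scatter attributes concrete, the following Python sketch models elem-mode `mem[idx[i, j]] = src[i, j]` with an atomic combine and an out-of-bounds policy. Everything here is illustrative: `scatter_elem` is a hypothetical helper, the `skip/clamp/wrap` behaviours are assumed from their names, and elements are applied in row-major order, so with `atomic="none"` the last writer to a duplicated index wins. That ordering is this model's choice and may not match what `scatterConflict = default` does on hardware.

```python
import operator

# Assumed meaning of each scatter_atomic_op value: how old and new combine.
_ATOMIC = {
    "none": lambda old, new: new,
    "add": operator.add,
    "max": max,
    "min": min,
}


def scatter_elem(mem, src, idx, atomic="none", oob="undefined"):
    """Hypothetical model of elem-mode scatter; returns the updated table."""
    out = list(mem)
    n = len(out)
    combine = _ATOMIC[atomic]
    for src_row, idx_row in zip(src, idx):
        for value, k in zip(src_row, idx_row):
            if not 0 <= k < n:
                if oob == "skip":      # assumed: drop the store
                    continue
                if oob == "clamp":     # assumed: pin to nearest valid index
                    k = min(max(k, 0), n - 1)
                elif oob == "wrap":    # assumed: index modulo table length
                    k %= n
                else:                  # "undefined": the model refuses to guess
                    raise IndexError(f"index {k} out of bounds")
            out[k] = combine(out[k], value)
    return out
```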

**Hardware Mapping:**

@@ -7136,6 +7151,11 @@ pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
outs(%mem : memref<...>)
{scatterAtomicOp = #pto<scatter_atomic_op add>,
scatterOob = #pto<scatter_oob skip>}

pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
outs(%mem : memref<...>)
{coalesce = #pto<coalesce elem>,
scatterConflict = #pto<scatter_conflict last>}
```

---
@@ -7282,6 +7302,7 @@ For padded elements: dst = PadVal(dst)
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile |
| `dst` | `pto.tile_buf` | Destination tile (with pad config) |
| `padValue` | `#pto.pad_value<...>` (optional) | Explicit `TFILLPAD<PadValue>` template argument for `loc=mat`. When present, it must match `dst`'s tile pad configuration. |

**Results:** None. Writes into `dst` via DPS pattern.

@@ -7290,6 +7311,7 @@ For padded elements: dst = PadVal(dst)
- `dst.pad` must not be `null`.
- `src` and `dst` element sizes must match, and the element size must be `1`, `2`, or `4` bytes.
- `dst.rows/cols` must match `src.rows/cols`.
- If `padValue` is present, `dst` must be `loc=mat` and `padValue` must equal the tile type's `pad`.
- For `loc=mat`, `src` and `dst` must be lowerable to the same `TFILLPAD` tile specialization, i.e. `validShape` and `pad` must be identical.
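
The constraints listed above can be restated as a small checker. This is a sketch, not the actual op verifier: `TileType` and `check_tfillpad` are hypothetical names, and the tile type is reduced to the handful of fields the constraints mention.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class TileType:
    """Hypothetical stand-in for the relevant fields of a !pto.tile_buf type."""
    loc: str
    elem_bytes: int
    shape: Tuple[int, int]        # (rows, cols)
    valid_shape: Tuple[int, int]  # (v_row, v_col)
    pad: Optional[str]


def check_tfillpad(src: TileType, dst: TileType,
                   pad_value: Optional[str] = None) -> List[str]:
    """Return a list of violated constraints; empty means the op is accepted."""
    errors = []
    if dst.pad is None:
        errors.append("dst.pad must not be null")
    if src.elem_bytes != dst.elem_bytes or dst.elem_bytes not in (1, 2, 4):
        errors.append("element sizes must match and be 1, 2, or 4 bytes")
    if src.shape != dst.shape:
        errors.append("dst rows/cols must match src rows/cols")
    if pad_value is not None:
        if dst.loc != "mat":
            errors.append("padValue requires dst loc=mat")
        if pad_value != dst.pad:
            errors.append("padValue must equal dst's pad")
    if dst.loc == "mat" and (src.valid_shape != dst.valid_shape
                             or src.pad != dst.pad):
        errors.append("loc=mat requires identical validShape and pad")
    return errors
```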

**Hardware Mapping:**
@@ -7300,6 +7322,9 @@ For padded elements: dst = PadVal(dst)

```mlir
pto.tfillpad ins(%src : !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>)

pto.tfillpad ins(%src : !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>)
{padValue = #pto.pad_value<max>}
```

---
@@ -7628,6 +7653,7 @@ dst[i, j] = S + linear_index(i, j) // or descending if requested
| Name | Type | Description |
|------|------|-------------|
| `S` | `Integer` | Starting value |
| `tmp` | `pto.tile_buf` (optional) | Optional scratch tile forwarded to the `pto-isa` tmp-buffer overload |
| `dst` | `pto.tile_buf` | Destination tile |
| `descending` | `BoolAttr` (default: false) | Generate descending sequence |

@@ -7648,6 +7674,7 @@ dst[i, j] = S + linear_index(i, j) // or descending if requested

```mlir
pto.tci ins(%start : i32) outs(%dst : !pto.tile_buf<...>)
pto.tci ins(%start, %tmp : i32, !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>)
```
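
The rule `dst[i, j] = S + linear_index(i, j)` can be sketched in Python as follows. Two things are assumed here rather than taken from the manual: `linear_index` is row-major (`i * cols + j`), and `descending` counts down from `S`. The helper name `tci` is only for illustration, and the optional `tmp` operand is treated as scratch that does not change the result.

```python
def tci(start, rows, cols, descending=False):
    """Hypothetical model of pto.tci; returns the generated tile as nested lists."""
    step = -1 if descending else 1  # assumed: descending counts down from `start`
    # Assumed row-major linear index: i * cols + j.
    return [[start + step * (i * cols + j) for j in range(cols)]
            for i in range(rows)]
```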

---
@@ -9091,7 +9118,8 @@ print(src)

| Name | Type | Description |
|------|------|-------------|
-| `src` | `pto.tile_buf` | Tile to print |
+| `src` | `pto.tile_buf` / global-memory view | Tile or global-memory view to print |
| `printFormat` | `i32` (optional, default: `0`) | Print format selector: `0=Width8_Precision4`, `1=Width8_Precision2`, `2=Width10_Precision6` |

**Results:** None.

@@ -9118,8 +9146,9 @@ print(src)

- **Formatting**:

-- Floating-point values: printed as `%6.2f`
-- Integer values: printed as `%6d`
+- `printFormat = 0`: `Width8_Precision4`
+- `printFormat = 1`: `Width8_Precision2`
+- `printFormat = 2`: `Width10_Precision6`
- For `GlobalTensor`, due to data size and buffer limitations, only elements within its logical shape (defined by `Shape`) are printed.
- For `tile_buf`, elements outside `valid_shape` are still printed and are marked with a `|` separator when partial validity is specified.

@@ -9131,6 +9160,7 @@ print(src)

```mlir
pto.tprint ins(%src : !pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
pto.tprint ins(%src : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>) {printFormat = 1 : i32}
```
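
To get a feel for the three `printFormat` selectors, the sketch below maps each one to a fixed-point format. It assumes that `WidthW_PrecisionP` corresponds to a printf-style `%W.Pf` conversion for floating-point values, which the manual does not state explicitly; `format_value` and `PRINT_FORMATS` are hypothetical names.

```python
# Assumed mapping: selector -> (field width, digits after the decimal point).
PRINT_FORMATS = {
    0: (8, 4),   # Width8_Precision4
    1: (8, 2),   # Width8_Precision2
    2: (10, 6),  # Width10_Precision6
}


def format_value(value: float, print_format: int = 0) -> str:
    """Format one element the way the selected print format is assumed to."""
    width, precision = PRINT_FORMATS[print_format]
    return f"{value:{width}.{precision}f}"
```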

---
20 changes: 20 additions & 0 deletions include/PTO/IR/PTOAttrs.td
@@ -527,6 +527,16 @@ def PTO_GatherOOBAttr : EnumAttr<PTO_Dialect, PTO_GatherOOBEnum, "gather_oob"> {
let summary = "MGATHER out-of-bounds handling mode";
}

def PTO_CoalesceEnum : PTO_I32Enum<
"Coalesce", "PTO MGATHER/MSCATTER coalesce mode", [
I32EnumAttrCase<"Row", 0, "row">,
I32EnumAttrCase<"Elem", 1, "elem">
]>;

def PTO_CoalesceAttr : EnumAttr<PTO_Dialect, PTO_CoalesceEnum, "coalesce"> {
let summary = "MGATHER/MSCATTER coalesce mode";
}

def PTO_ScatterAtomicOpEnum : PTO_I32Enum<
"ScatterAtomicOp", "PTO MSCATTER atomic mode", [
I32EnumAttrCase<"None", 0, "none">,
@@ -551,6 +561,16 @@ def PTO_ScatterOOBAttr : EnumAttr<PTO_Dialect, PTO_ScatterOOBEnum, "scatter_oob"
let summary = "MSCATTER out-of-bounds handling mode";
}

def PTO_ScatterConflictEnum : PTO_I32Enum<
"ScatterConflict", "PTO MSCATTER scatter conflict handling", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"Last", 1, "last">
]>;

def PTO_ScatterConflictAttr : EnumAttr<PTO_Dialect, PTO_ScatterConflictEnum, "scatter_conflict"> {
let summary = "MSCATTER scatter conflict handling mode";
}

def PTO_AccToVecMode_Enum : PTO_I32Enum<"AccToVecMode", "TMOV acc-to-vec mode", [
I32EnumAttrCase<"SingleModeVec0", 0, "single_mode_vec0">,
I32EnumAttrCase<"SingleModeVec1", 1, "single_mode_vec1">,
30 changes: 15 additions & 15 deletions include/PTO/IR/PTOOps.td
@@ -2661,6 +2661,7 @@ def MGatherOp : PTO_TOp<"mgather", [
PTODpsType:$mem,
PTODpsType:$idx,
PTODpsType:$dst,
OptionalAttr<PTO_CoalesceAttr>:$coalesce,
DefaultValuedAttr<PTO_GatherOOBAttr, "::mlir::pto::GatherOOB::Undefined">:$gatherOob);

let results = (outs);
@@ -2839,8 +2840,10 @@ def MScatterOp : PTO_TOp<"mscatter", [
PTODpsType:$src,
PTODpsType:$idx,
PTODpsType:$mem, // outs target
OptionalAttr<PTO_CoalesceAttr>:$coalesce,
DefaultValuedAttr<PTO_ScatterAtomicOpAttr, "::mlir::pto::ScatterAtomicOp::None">:$scatterAtomicOp,
-DefaultValuedAttr<PTO_ScatterOOBAttr, "::mlir::pto::ScatterOOB::Undefined">:$scatterOob
+DefaultValuedAttr<PTO_ScatterOOBAttr, "::mlir::pto::ScatterOOB::Undefined">:$scatterOob,
OptionalAttr<PTO_ScatterConflictAttr>:$scatterConflict
);

let results = (outs);
@@ -3217,17 +3220,13 @@ def TCIOp : PTO_TOp<"tci", [

let arguments = (ins
AnyInteger:$S,
Optional<PTODpsType>:$tmp,
PTODpsType:$dst,
DefaultValuedAttr<BoolAttr, "false">:$descending
);
let results = (outs);

-let assemblyFormat = [{
-`ins` `(` $S
-attr-dict
-`:` type($S) `)`
-`outs` `(` $dst `:` qualified(type($dst) ) `)`
-}];
+let hasCustomAssemblyFormat = 1;

let hasVerifier = 1;

@@ -4209,7 +4208,8 @@ def TFillPadOp : PTO_TOp<"tfillpad", [

let arguments = (ins
PTODpsType:$src,
-PTODpsType:$dst
+PTODpsType:$dst,
OptionalAttr<PTO_PadValueAttr>:$padValue
);

let results = (outs);
@@ -5399,22 +5399,21 @@ def TRowExpandAddOp: PTO_TOp<"trowexpandadd", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "TROWEXPANDADD: Row-wise broadcast add: add a per-row scalar vector src1 to each row of src0.";
let description = [{
pto-isa has overloads with/without tmp; optional tmp for scratch.
}];

let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
Optional<PTODpsType>:$tmp,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

-let assemblyFormat = [{
-`ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)`
-`outs` `(` $dst `:` qualified(type($dst) ) `)`
-attr-dict
-}];
+let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
@@ -6284,7 +6283,8 @@ def TPrintOp: PTO_TOp<"tprint", [
let summary = "TPRINT: Print the contents of a Tile or GlobalTensor for debugging purposes directly from device code.";

let arguments = (ins
-PTODpsType:$src
+PTODpsType:$src,
DefaultValuedOptionalAttr<I32Attr, "0">:$printFormat
);

let results = (outs);