[release/2.11] [ROCm] fix triu/tril for 64-bit indexing for large matrices #3252

Open
dnikolaev-amd wants to merge 1 commit into release/2.11 from dnikolaev/rocm_triangular_kernel_fix_rel2.11

dnikolaev-amd commented May 21, 2026

This PR fixes the triu/tril kernel on ROCm for large tensors that require 64-bit indexing.

ROCm limits the total number of threads in a launch (grid_size * block_size) to at most 2^32. Exceeding this limit produces incorrect results or an error. On ROCm, the kernel now uses a strided loop with a capped grid size to stay within the limit.
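To make the limit concrete, here is a minimal Python sketch of the arithmetic and of the strided-loop idea. It is not the HIP kernel from this PR: the block size of 256, the one-element-per-thread mapping, and the helper names `launch_config` and `strided_indices` are assumptions chosen for illustration. The 2^32 bound is taken as stated in the description.

```python
MAX_TOTAL_THREADS = 2**32  # limit on grid_size * block_size, as stated in the PR description


def launch_config(numel, block_size=256):
    """Return (grid_size, needs_loop) for a 1-D launch with one element per thread.

    block_size=256 is an illustrative assumption, not the value used by the kernel.
    """
    wanted_grid = (numel + block_size - 1) // block_size  # ceiling division
    max_grid = MAX_TOTAL_THREADS // block_size
    grid_size = min(wanted_grid, max_grid)
    # If the grid had to be capped, each thread must loop over several elements.
    return grid_size, grid_size < wanted_grid


def strided_indices(thread_id, total_threads, numel):
    """Element indices that one thread visits in a strided loop."""
    return range(thread_id, numel, total_threads)


print(launch_config(40000 * 40000))  # (6250000, False)
print(launch_config(80000 * 80000))  # (16777216, True)

# With 8 threads and 21 elements, every element is visited exactly once.
covered = sorted(i for tid in range(8) for i in strided_indices(tid, 8, 21))
print(covered == list(range(21)))  # True
```

Under these assumptions, a (40000, 40000) tensor still fits in a single pass, while an (80000, 80000) tensor (6.4e9 elements) would need more threads than the limit allows, so each thread has to process several elements.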

Fixes pytorch#165966

A benchmark was created to identify the optimal grid size and elements-per-thread configuration for maximum performance across all dtypes on MI300X. The shapes were taken from the kernel optimization PR pytorch#115013

I ran the benchmark before and after my changes. The results are as follows:

  • Fixed the triangular kernel on ROCm for large tensors that require 64-bit indexing
  • Significant performance improvement on many shapes for 32- and 64-bit data types
  • Minor performance regression on some shapes for 8- and 16-bit data types
  • Significant performance degradation on the very large (80000, 80000) shape for 8- and 64-bit data types

Speedup results (k=0):

| mode | shape | float16 | float32 | float64 | int8 |
|--|--|--|--|--|--|
| in_place | (1, 8192, 8192) | 121% | 175% | 277% | 111% |
| in_place | (3072, 3072) | 145% | 212% | 277% | 107% |
| in_place | (1, 3072, 3072) | 144% | 178% | 272% | 94% |
| in_place | (1, 1, 3072, 3072) | 134% | 156% | 262% | 86% |
| in_place | (4, 1024, 1024) | 116% | 138% | 193% | 100% |
| in_place | (4, 1021, 1021) | 102% | 141% | 219% | 109% |
| in_place | (256, 128, 256) | 124% | 162% | 296% | 77% |
| in_place | (128, 257, 125) | 105% | 128% | 168% | 120% |
| in_place | (20480, 16, 16) | 108% | 135% | 209% | 102% |
| in_place | (40000, 40000) | 179% | 228% | 305% | 116% |
| in_place | (80000, 80000) | 146% | 166% | 88% | 89% |
| out_of_place | (1, 8192, 8192) | 89% | 114% | 135% | 89% |
| out_of_place | (3072, 3072) | 96% | 159% | 169% | 120% |
| out_of_place | (1, 3072, 3072) | 96% | 123% | 175% | 103% |
| out_of_place | (1, 1, 3072, 3072) | 99% | 103% | 176% | 91% |
| out_of_place | (4, 1024, 1024) | 102% | 114% | 149% | 105% |
| out_of_place | (4, 1021, 1021) | 96% | 114% | 148% | 114% |
| out_of_place | (256, 128, 256) | 94% | 123% | 148% | 103% |
| out_of_place | (128, 257, 125) | 103% | 114% | 143% | 118% |
| out_of_place | (20480, 16, 16) | 99% | 112% | 126% | 108% |
| out_of_place | (40000, 40000) | 91% | 111% | 135% | 80% |
| out_of_place | (80000, 80000) | 84% | 102% | 46% | 70% |
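The description does not say how the percentages were derived. Going by the summary bullets, values above 100% are the cases that got faster, so one plausible reading is the ratio of kernel time before the change to kernel time after it. The sketch below encodes that assumed definition; the function name `speedup_percent` and the timings are hypothetical and are not measurements from this PR.

```python
def speedup_percent(time_before, time_after):
    """Relative speed of the new kernel in percent (assumed definition).

    Values above 100 mean the kernel got faster; values below 100 mean it got slower.
    """
    if time_after <= 0:
        raise ValueError("time_after must be positive")
    return round(100.0 * time_before / time_after)


# Hypothetical kernel times in microseconds, for illustration only.
print(speedup_percent(200.0, 100.0))  # 200: twice as fast after the change
print(speedup_percent(50.0, 100.0))   # 50: half as fast after the change
```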

Benchmark:

import itertools

import pandas as pd
import torch


def in_place(t, k=0):
    t.triu_(k)
    return t


def out_of_place(t, k=0):
    return t.triu(k)


col_names = ["mode", "shape", "k", "float16", "float32", "float64", "int8"]
dtype_list = [torch.float64, torch.float32, torch.float16, torch.int8]
shape_list = [
    (1, 8192, 8192),
    (3072, 3072),
    (1, 3072, 3072),
    (1, 1, 3072, 3072),
    (4, 1024, 1024),
    (4, 1021, 1021),
    (256, 128, 256),
    (128, 257, 125),
    (20480, 16, 16),
    (40000, 40000),
    # (80000, 80000),
    # (100000, 100000),
]
fn_list = [in_place, out_of_place]

rows = []

for fn, shape in itertools.product(fn_list, shape_list):
    # Main diagonal plus offsets of roughly half the last dimension in each direction.
    # Note: -shape[-1] // 2 floors toward negative infinity, so for an odd size n
    # the negative offset is -(n // 2) - 1.
    for k in [0, shape[-1] // 2, -shape[-1] // 2]:
        row = {
            "mode": fn.__name__,
            "shape": str(shape),
            "k": "\t0" if k == 0 else "\t+1/2" if k > 0 else "\t-1/2",
        }
        for dtype in dtype_list:
            t = torch.empty(shape, device="cuda", dtype=dtype).fill_(-1)

            # Validate the device result against the CPU result.
            t_h = t.cpu()
            out = fn(t, k)
            out_h = fn(t_h, k)
            torch.testing.assert_close(out, out_h, check_device=False)

            for _ in range(50):  # warmup
                fn(t, k)

            torch.cuda.synchronize()
            prof = torch.profiler.profile()
            prof.start()

            for _ in range(100):
                fn(t, k)

            torch.cuda.synchronize()
            prof.stop()

            stats = prof.key_averages()
            # Get the averaged device time of the triu_tril_kernel.
            triu_time = round([s for s in stats if "triu_tril_kernel" in s.key][0].device_time, 4)
            row[str(dtype).replace("torch.", "")] = triu_time
        rows.append(row)

df = pd.DataFrame(rows, columns=col_names)
# df.to_csv("stats_triu.csv", index=False)
print(df.to_csv(index=False, sep=";"))
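The benchmark sweeps the diagonal offset `k`, so it may help to recall what the offset does: `triu` keeps element (i, j) when j - i >= k, and `tril` keeps it when j - i <= k. The following pure-Python reference on nested lists is a sketch for illustration only. The names `triu_ref` and `tril_ref` are hypothetical, and the code is unrelated to the kernel implementation.

```python
def triu_ref(rows, k=0):
    """Keep elements on or above the k-th diagonal (j - i >= k), zero the rest."""
    return [[v if j - i >= k else 0 for j, v in enumerate(row)]
            for i, row in enumerate(rows)]


def tril_ref(rows, k=0):
    """Keep elements on or below the k-th diagonal (j - i <= k), zero the rest."""
    return [[v if j - i <= k else 0 for j, v in enumerate(row)]
            for i, row in enumerate(rows)]


# A 3x4 matrix filled with -1, mirroring the fill value used in the benchmark.
m = [[-1] * 4 for _ in range(3)]

print(triu_ref(m, 0))  # [[-1, -1, -1, -1], [0, -1, -1, -1], [0, 0, -1, -1]]
print(triu_ref(m, 2))  # [[0, 0, -1, -1], [0, 0, 0, -1], [0, 0, 0, 0]]
print(tril_ref(m, 0))  # [[-1, 0, 0, 0], [-1, -1, 0, 0], [-1, -1, -1, 0]]
```

A positive `k` shrinks the region kept by `triu` and a negative `k` grows it, which is why the benchmark covers both directions in addition to the main diagonal.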

Pull Request resolved: pytorch#179717
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily

(cherry picked from commit 4ee8611)

Co-authored-by: Jeff Daily <jeff.daily@amd.com>

rocm-repo-management-api Bot commented May 21, 2026

Jenkins build for 29807a8db3edb5c0d0b17355fbe4b0e0b5cce350 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

dnikolaev-amd changed the title from "[ROCm] fix triu/tril for 64-bit indexing for large matrices (#179717)" to "[release/2.11] [ROCm] fix triu/tril for 64-bit indexing for large matrices" on May 21, 2026