[release/2.11] [ROCm] fix triu/tril for 64-bit indexing for large matrices #3252
Open
dnikolaev-amd wants to merge 1 commit into
…179717)

This PR fixes the triu/tril kernel on ROCm for large tensors that require 64-bit indexing.

ROCm limits the total number of threads (grid_size * block_size) to at most 2^32. Exceeding this limit produces incorrect results or an error. On ROCm, the kernel now uses a strided loop with a limited grid size to stay within the limit.

Fixes pytorch#165966

A benchmark was created to identify the optimal grid size and elements-per-thread configuration for maximum performance across all dtypes on MI300X. The shapes were taken from the kernel optimization PR pytorch#115013.

I ran the benchmark before and after my changes. The results are as follows:

- Fixed the triangular kernel on ROCm for large tensors that require 64-bit indexing
- Significant performance improvement in many shapes for 32- and 64-bit data types
- Minor performance regression for some shapes for 8- and 16-bit data types
- Significant performance degradation on very large shapes for 8- and 64-bit data types

Speedup results (k=0):

| mode | shape | float16 | float32 | float64 | int8 |
|--|--|--|--|--|--|
| in_place | (1, 8192, 8192) | 121% | 175% | 277% | 111% |
| in_place | (3072, 3072) | 145% | 212% | 277% | 107% |
| in_place | (1, 3072, 3072) | 144% | 178% | 272% | 94% |
| in_place | (1, 1, 3072, 3072) | 134% | 156% | 262% | 86% |
| in_place | (4, 1024, 1024) | 116% | 138% | 193% | 100% |
| in_place | (4, 1021, 1021) | 102% | 141% | 219% | 109% |
| in_place | (256, 128, 256) | 124% | 162% | 296% | 77% |
| in_place | (128, 257, 125) | 105% | 128% | 168% | 120% |
| in_place | (20480, 16, 16) | 108% | 135% | 209% | 102% |
| in_place | (40000, 40000) | 179% | 228% | 305% | 116% |
| in_place | (80000, 80000) | 146% | 166% | 88% | 89% |
| out_of_place | (1, 8192, 8192) | 89% | 114% | 135% | 89% |
| out_of_place | (3072, 3072) | 96% | 159% | 169% | 120% |
| out_of_place | (1, 3072, 3072) | 96% | 123% | 175% | 103% |
| out_of_place | (1, 1, 3072, 3072) | 99% | 103% | 176% | 91% |
| out_of_place | (4, 1024, 1024) | 102% | 114% | 149% | 105% |
| out_of_place | (4, 1021, 1021) | 96% | 114% | 148% | 114% |
| out_of_place | (256, 128, 256) | 94% | 123% | 148% | 103% |
| out_of_place | (128, 257, 125) | 103% | 114% | 143% | 118% |
| out_of_place | (20480, 16, 16) | 99% | 112% | 126% | 108% |
| out_of_place | (40000, 40000) | 91% | 111% | 135% | 80% |
| out_of_place | (80000, 80000) | 84% | 102% | 46% | 70% |

benchmark:

```python
import torch
import itertools
import pandas as pd

def in_place(t, k = 0):
    t.triu_(k)
    return t

def out_of_place(t, k = 0):
    return t.triu(k)

col_names = ["mode","shape", "k", "float16", "float32", "float64", "int8"]
dtype_list = [torch.float64, torch.float32, torch.float16, torch.int8]
shape_list = [
    (1, 8192, 8192),
    (3072, 3072),
    (1, 3072, 3072),
    (1, 1, 3072, 3072),
    (4, 1024, 1024),
    (4, 1021, 1021),
    (256, 128, 256),
    (128, 257, 125),
    (20480, 16, 16),
    (40000, 40000),
    # (80000, 80000),
    # (100000, 100000),
]
fn_list = [in_place, out_of_place]
rows = []
for fn, shape in itertools.product(fn_list, shape_list):
    for k in [0, shape[-1] // 2, -shape[-1] // 2]:
        row = {
            "mode": fn.__name__,
            "shape": str(shape),
            "k": "\t0" if k == 0 else "\t+1/2" if k > 0 else "\t-1/2",
        }
        for dtype in dtype_list:
            t = torch.empty(shape, device="cuda", dtype=dtype).fill_(-1)
            # validate
            t_h = t.cpu()
            out = fn(t, k)
            out_h = fn(t_h, k)
            torch.testing.assert_close(out, out_h, check_device=False)
            for _ in range(50): #warmup
                fn(t, k)
            torch.cuda.synchronize()
            prof = torch.profiler.profile()
            prof.start()
            for _ in range(100):
                fn(t, k)
            torch.cuda.synchronize()
            prof.stop()
            stats = prof.key_averages()
            # get the time of the triu_tril_kernel
            triu_time = round([s for s in stats if "triu_tril_kernel" in s.key][0].device_time, 4)
            row[str(dtype).replace("torch.", "")] = triu_time
        rows.append(row)
df = pd.DataFrame(rows, columns=col_names)
# df.to_csv("stats_triu.csv", index=False)
print(df.to_csv(index=False, sep=";"))
```

Pull Request resolved: pytorch#179717
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
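
For readers unfamiliar with the approach, the strided loop with a limited grid size mentioned in the description can be illustrated with a small CPU-only simulation. This is a minimal sketch, not the kernel code from this PR: the helper names (`capped_grid_size`, `strided_indices`, `triu_flat`) and the block size of 256 are assumptions made for illustration, and only the 2^32 limit on grid_size * block_size comes from the description. The idea is that the grid is capped so the thread count stays within the limit, and each thread then walks the tensor in steps of the total thread count so every element is still covered.

```python
# Limit on grid_size * block_size quoted in the PR description.
MAX_TOTAL_THREADS = 2**32


def capped_grid_size(numel, block_size=256, elems_per_thread=1):
    """Hypothetical helper: number of blocks to launch, capped so that
    grid_size * block_size never exceeds MAX_TOTAL_THREADS."""
    per_block = block_size * elems_per_thread
    wanted = -(-numel // per_block)  # integer ceiling division
    max_grid = MAX_TOTAL_THREADS // block_size
    return min(wanted, max_grid)


def strided_indices(thread_id, total_threads, numel):
    """Linear indices visited by one thread in a grid-stride loop."""
    return range(thread_id, numel, total_threads)


def triu_flat(values, rows, cols, k, grid_size, block_size):
    """CPU simulation of triu on a row-major matrix: every simulated thread
    zeroes the elements below the k-th diagonal among the indices it owns."""
    total_threads = grid_size * block_size
    out = list(values)
    for tid in range(total_threads):
        for idx in strided_indices(tid, total_threads, rows * cols):
            r, c = divmod(idx, cols)
            if c - r < k:  # below the k-th diagonal
                out[idx] = 0
    return out


if __name__ == "__main__":
    # One thread per element would need more than 2^32 threads for this shape.
    print(80000 * 80000 > MAX_TOTAL_THREADS)
    print(capped_grid_size(80000 * 80000))
    # A 3x4 matrix processed by only 4 simulated threads (2 blocks of 2).
    print(triu_flat(list(range(1, 13)), 3, 4, 0, grid_size=2, block_size=2))
```

Under these assumptions, the (80000, 80000) shape from the table has 6.4e9 elements, more than 2^32, so it is the case where the grid gets capped and each thread handles more than one element, while (40000, 40000) with 1.6e9 elements fits without capping.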
Jenkins build for 29807a8db3edb5c0d0b17355fbe4b0e0b5cce350 commit finished as FAILURE |
(cherry picked from commit 4ee8611)