Remove padding from scales for hipBLASlt calls #442
base: dev
Changes from all commits
1a24ff2
5bbfb4b
cbbd027
9f8f611
9a5980f
ff72142
@@ -41,7 +41,7 @@
     Float8Quantizer,
     Float8Tensor,
 )
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from transformer_engine.pytorch.distributed import checkpoint
 from utils import ModelConfig
@@ -913,6 +913,78 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
     )
     torch.cuda.synchronize()
 
+
+@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
+@pytest.mark.parametrize("N", [32])
+@pytest.mark.parametrize("K", [128])
+@pytest.mark.parametrize("M", [32])
Collaborator
Better to use a non-multiple of 32 here, to check that this path really unpads.

Contributor Author
We require block sizes of 32 at the Python level, so a non-multiple is not possible. We are padding scales, so we will see a rowwise scale of (1, 4) padded to (128, 4) and a colwise scale of (4, 1) padded to (4, 128).
+@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
+def test_sanity_mxfp8_gemm_with_padding(N, K, M, datatype):
+    """Test the unpadding functionality in rocm"""
+    dtype = tex.DType.kFloat8E4M3
+    quantizer = MXFP8Quantizer(dtype)
+
+    input_dtype = torch.randn(M, K, device="cuda", dtype=datatype)
+    weight_dtype = torch.randn(N, K, device="cuda", dtype=datatype)
+
+    input_data = quantizer.make_empty((M, K), device="cuda")
+    weight_data = quantizer.make_empty((N, K), device="cuda")
+
+    quantizer.update_quantized(input_dtype, input_data)
+    quantizer.update_quantized(weight_dtype, weight_data)
+
+    out_ref = general_gemm(
+        weight_data,
+        input_data,
+        get_workspace(),
+        datatype,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    torch.cuda.synchronize()
+
+    row_scale_inv = input_data._rowwise_scale_inv
+    rows, cols = row_scale_inv.shape
+    row_padded_scale_inv = torch.zeros((128, 4), dtype=row_scale_inv.dtype, device="cuda")
+    row_padded_scale_inv[:rows, :cols] = row_scale_inv
+
+    col_scale_inv = input_data._columnwise_scale_inv
+    rows, cols = col_scale_inv.shape
+    col_padded_scale_inv = torch.zeros((4, 128), dtype=col_scale_inv.dtype, device="cuda")
+    col_padded_scale_inv[:rows, :cols] = col_scale_inv
+
+    input_padded = MXFP8Tensor(
+        shape=input_data.shape,
+        rowwise_data=input_data._rowwise_data.clone(),
+        rowwise_scale_inv=row_padded_scale_inv,
+        columnwise_data=input_data._columnwise_data.clone(),
+        columnwise_scale_inv=col_padded_scale_inv,
+        fp8_dtype=tex.DType.kFloat8E4M3,
+        quantizer=quantizer,
+        dtype=datatype
+    )
+
+    out_pass1 = general_gemm(
+        weight_data,
+        input_padded,
+        get_workspace(),
+        datatype,
+        bias=None,
+        use_split_accumulator=False
+    )
+    torch.cuda.synchronize()
+
+    assert row_scale_inv.shape == input_padded._rowwise_scale_inv.shape, \
+        ("Shape mismatch in rowwise scales")
+    assert col_scale_inv.shape == input_padded._columnwise_scale_inv.shape, \
+        ("Shape mismatch in colwise scales")
+    torch.testing.assert_close(row_scale_inv, input_padded._rowwise_scale_inv,
+                               rtol=1e-7, atol=1e-7, msg="rowwise scale mismatch")
+    torch.testing.assert_close(col_scale_inv, input_padded._columnwise_scale_inv,
+                               rtol=1e-7, atol=1e-7, msg="colwise scale mismatch")
+    torch.testing.assert_close(out_pass1[0], out_ref[0],
+                               rtol=1e-2, atol=1e-2, msg="GEMM output mismatch")
+
+
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 def test_replace_raw_data_for_float8tensor():
Collaborator
Is this a hipBLASlt limitation?

Contributor Author
Yes, these are the values that the hipBLASlt team provided to us. I tested just in case, but nothing smaller than 128 works for K.

Collaborator
Is the 32x128x32 config still needed alongside 16x128x16, then?

Contributor Author
I would say it makes sense to keep it. It lets us test a TE-acceptable size with 32 while also ensuring that unpadding and hipBLASlt work with 16.

Collaborator
In that case I'd change 32x128x32 to 32x128x16 so the two are tested together.
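A minimal sketch of what that suggested change could look like in the parametrize decorators (illustrative only; treating the trailing 32 as the M dimension is an assumption, not part of this diff):

# Hypothetical sketch of the reviewer's suggestion, not the actual change in this PR.
# Assumes the N/K/M ordering used by the decorators above and that M is the
# dimension dropped to 16 so the unpadding path is exercised alongside N=32.
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
@pytest.mark.parametrize("N", [32])      # TE-acceptable, block-aligned size
@pytest.mark.parametrize("K", [128])     # hipBLASlt needs K of at least 128 here
@pytest.mark.parametrize("M", [16])      # 16 forces the scale-unpadding path
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_mxfp8_gemm_with_padding(N, K, M, datatype):
    ...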