-
Notifications
You must be signed in to change notification settings - Fork 23
Triton norms dispatch refactor #305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
7c72ee0
45e6236
5b2ea1c
077b8fc
0011e5f
5ada1bd
26298c8
cc02444
18eb6e7
a72b507
a64b5f1
1d2554c
fd59057
bbd4240
92aecf2
12bb156
1925039
6f9b6c5
dc3ed87
24fbcc4
0d6d00f
499d14b
c526c8d
26cd12c
d031266
8a5a786
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,29 +1,29 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # License for AMD contributions = MIT. See LICENSE for more information | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from itertools import product | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.fp8 import FP8GlobalStateManager | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch import cpp_extensions as tex | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.triton_kernels.norm_common import get_ln_sm_margin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.triton_kernels.utils import get_ln_sm_margin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.triton_kernels.common import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch_dtype_to_te_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| te_dtype_to_torch_dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.triton_kernels.rmsnorm import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
ipanfilo marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| te_rmsnorm_bwd_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| te_rmsnorm_fwd_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.triton_kernels.layernorm import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformer_engine.pytorch.triton_kernels.norms_common import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| te_layernorm_bwd_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| te_layernorm_fwd_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| te_rmsnorm_bwd_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| te_rmsnorm_fwd_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -406,11 +406,11 @@ def _compare_output_tensors( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| quantization, fp8_dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tols = dtype_tols(out_triton.dtype if quantization is None else fp8_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func = partial(te_compare_results, **tols, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func = partial(te_compare_results, **tols, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dq_out_triton = out_triton.dequantize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dq_out_hip = out_hip.dequantize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=dq_out_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=dq_out_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Output does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -428,7 +428,7 @@ def _compare_output_tensors( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not out_hip._transpose_invalid: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # The transpose data are generally uint8 so we must convert | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # them for floating point comparison. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=out_triton._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=out_hip._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Output transpose does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -437,67 +437,66 @@ def _compare_output_tensors( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif quantization == "mxfp8": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(out_triton, MXFP8Tensor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Expected a MXFP8Tensor but got {type(out_triton)} instead.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO(micky774): Figure out if we need to apply the same view | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # trick to MXFP8 data as we do to FP8 transpose data. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # I suspect not. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if out_hip._rowwise_data is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=out_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=out_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Output rowwise data does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_triton._rowwise_data = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert out_triton._rowwise_data is not None, "Expected rowwise data." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert out_triton._rowwise_data is None, "Expected no rowwise data." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # We use higher precision for the scales | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quantization == "fp8": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=out_triton._scale_inv, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=out_hip._scale_inv, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Output scale inverse does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif quantization == "mxfp8": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_rscale_triton = out_triton._rowwise_scale_inv is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_rscale_hip = out_hip._rowwise_scale_inv is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # The scale_inv values may differ slightly, but will still dequantize close enough to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # pass the earlier comparisons. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func = partial(te_compare_results, atol=1, rtol=0, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For mxfp8 data and scale inv comparison, we can reuse the same logic in cpp gtest: TransformerEngine/tests/cpp/test_common.cu Line 730 in 0dfee56
TransformerEngine/tests/cpp/operator/test_cast_mxfp8.cu Lines 331 to 355 in 0dfee56
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We essentially already do this implicitly by relying on the dequantization of the |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # The MXFP8 tensors carry their scale_inv values in a padded | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # format, hence we must omit the padded values. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_shape = out_triton.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unpad_rscale_inv_shape = (math.prod(input_shape[:-1]), input_shape[-1] // MXFP8_BLOCK_SCALING_SIZE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_rscale_triton != has_rscale_hip: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg = "Expected rowwise scale to " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_rscale_hip: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg += "not " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg += "be None." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(msg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_rscale_triton: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=out_triton._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=out_hip._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=out_triton._rowwise_scale_inv[:unpad_rscale_inv_shape[0], :unpad_rscale_inv_shape[1]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=out_hip._rowwise_scale_inv[:unpad_rscale_inv_shape[0], :unpad_rscale_inv_shape[1]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Output rowwise scale inverse does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_cscale_triton = out_triton._columnwise_scale_inv is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_cscale_hip = out_hip._columnwise_scale_inv is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_cscale_triton != has_cscale_hip: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg = "Expected columnwwise scale to " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg = "Expected columnwise scale to " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_cscale_hip: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg += "not " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg += "be None." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(msg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_cscale_triton: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=out_triton._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=out_hip._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=out_triton._columnwise_scale_inv[:unpad_rscale_inv_shape[1], :unpad_rscale_inv_shape[0]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=out_hip._columnwise_scale_inv[:unpad_rscale_inv_shape[1], :unpad_rscale_inv_shape[0]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Output columnwise scale inverse does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _compare_quantizers( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quantizer_triton, quantizer_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quantization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quantization is None: return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quantizer_triton.dtype != quantizer_hip.dtype: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Expected matching quantizer dtypes, but got " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -511,12 +510,12 @@ def _compare_quantizers( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Expected matching quantizer {usage} but got {qt_usage=} != {qh_usage=}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quantization == "fp8": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=quantizer_triton.scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=quantizer_hip.scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Quantizer scale does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=quantizer_triton.amax, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=quantizer_hip.amax, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"Quantizer amax does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -529,15 +528,15 @@ def _compare_stat_tensors( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # We use higher precision for the remaining outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=rsigma_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=rsigma_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"rsigma does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if norm == "layer": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=mu_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=mu_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"mu does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -579,20 +578,20 @@ def _compare_bwd_tensors( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| dbeta_triton, dbeta_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func = partial(te_compare_results, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func = partial(te_compare_results, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=dx_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=dx_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"dx does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=dgamma_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=dgamma_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"dgamma does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if norm == "layer": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compare_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actual=dbeta_triton, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected=dbeta_hip, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg=lambda msg: f"dbeta does not match triton <-> hip\n\n{msg}\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
ipanfilo marked this conversation as resolved.
Show resolved
Hide resolved
|
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copyright date
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, along with a few others I missed. |
ipanfilo marked this conversation as resolved.
Show resolved
Hide resolved
|
ipanfilo marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.