Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions tools/clang/lib/Headers/hlsl/dx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, K> >::type
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, M> >::type
// clang-format on
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, K> Vec) {
Expand All @@ -479,29 +479,29 @@ Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, K> >::type
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, M> Vec, vector<BiasElTy, K> Bias) {
vector<OutputElTy, K> Result;
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias) {
vector<OutputElTy, M> Result;
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
hlsl::is_signed<OutputElTy>::value,
Vec, MatrixDT, Bias, MatrixDT);
return Result;
}

template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
vector<OutputElTy, K> >::type
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
vector<BiasElTy, K> Bias) {
vector<OutputElTy, K> Result;
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
vector<BiasElTy, M> Bias) {
vector<OutputElTy, M> Result;
__builtin_LinAlg_MatrixVectorMultiplyAdd(
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
InterpVec.Data, InterpVec.Interpretation, Bias, MatrixDT);
Expand All @@ -512,35 +512,35 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
vector<OutputElTy, K> >::type
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, M> Vec, VectorRef<BiasElTy, K> BiasRef) {
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef) {
using BiasVecTy =
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, K>;
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
vector<OutputElTy, K> Result;
vector<OutputElTy, M> Result;
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
hlsl::is_signed<OutputElTy>::value,
Vec, MatrixDT, BiasVec, BiasElTy);
return Result;
}

template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
vector<OutputElTy, K> >::type
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
VectorRef<BiasElTy, K> BiasRef) {
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
VectorRef<BiasElTy, M> BiasRef) {
using BiasVecTy =
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, K>;
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
vector<OutputElTy, K> Result;
vector<OutputElTy, M> Result;
__builtin_LinAlg_MatrixVectorMultiplyAdd(
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasElTy);
Expand Down
38 changes: 20 additions & 18 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,60 @@
#include <dx/linalg.h>
using namespace dx::linalg;

using MatrixATy = Matrix<ComponentType::F16, 8, 8, MatrixUse::A, MatrixScope::Thread>;
using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>;
using MatrixAccumTy = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;

ByteAddressBuffer BAB : register(t0);

[numthreads(4, 4, 4)]
void main(uint ID : SV_GroupID) {

// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N8U0S0(
// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0(
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2)
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
MatrixATy Mat1 = MatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 8);

vector<half, 8> vec1 = 10.3f;
vector<half, 4> vec1 = 10.3f;

// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N8U0S0.v8f16(i32 -2147483623,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %3, i1 true, <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926>, i32 8)
// CHECK-SAME: ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %3, i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: half 0xH4926>, i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
vector<half, 8> vec2 = Multiply<half>(Mat1, vec1);

// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926>, i32 8, <8 x half> %[[VEC2]], i32 8)
// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: half 0xH4926>, i32 8, <8 x half> %[[VEC2]], i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 8> vec3 = MultiplyAdd<half>(Mat1, vec1, vec2);

// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x half> %[[VEC3]], i32 8)
// CHECK: %[[VEC20:.*]] = shufflevector
vector<half, 4> vec20 = (vector<half, 4>)vec2;

// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x half> %[[VEC3]], i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
InterpretedVector<half, 8, ComponentType::F16> interpVec2 = MakeInterpretedVector<ComponentType::F16>(vec2);
InterpretedVector<half, 4, ComponentType::F16> interpVec2 = MakeInterpretedVector<ComponentType::F16>(vec20);
vector<half, 8> vec4 = MultiplyAdd<half>(Mat1, interpVec2, vec3);

// CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2) ; RawBufferVectorLoad(buf,index,elementOffset,alignment)

// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0

// CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC3]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
// CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
// CHECK-SAME:; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
VectorRef<ComponentType::I16, 8> memBias = {BAB, 4096};
vector<half, 8> vec5 = MultiplyAdd<half>(Mat1, vec3, memBias);
vector<half, 8> vec5 = MultiplyAdd<half>(Mat1, interpVec2, memBias);

// CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)

// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0

// CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
// CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 8> vec6 = MultiplyAdd<half>(Mat1, interpVec2, memBias);

Expand Down
4 changes: 2 additions & 2 deletions utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(ref LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(ref LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp, in numeric<> bias, in uint biasInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp, in numeric<c> bias, in uint biasInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB);
Expand Down
Loading