Skip to content

Commit ef3ea3c

Browse files
committed
chore: cleanup the tests + ci fix
1 parent ee533e2 commit ef3ea3c

15 files changed

+436
-7100
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,5 @@ jobs:
5959
ctest -j ${{env.parallel_processes}} -T memcheck -C ${{matrix.build_type}} --test-dir submission_25_05_01 --output-on-failure
6060
ctest -j ${{env.parallel_processes}} -T memcheck -C ${{matrix.build_type}} --test-dir submission_25_05_08 --output-on-failure
6161
ctest -j ${{env.parallel_processes}} -T memcheck -C ${{matrix.build_type}} --test-dir submission_25_05_15 --output-on-failure
62-
ctest -j ${{env.parallel_processes}} -T memcheck -C ${{matrix.build_type}} --output-on-failure
62+
ctest -j ${{env.parallel_processes}} -T memcheck -C ${{matrix.build_type}} --output-on-failure -E "^Test gemm generation"
6363

CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,12 @@ set(TEST_FILES
121121

122122
set(TEST_KERNELS
123123
matmul.test.h
124+
matmul.test.cpp
124125
matmul_16_6_1.test.cpp
125126
matmul_16_6_k.test.cpp
126127
matmul_16m_4n_k.test.cpp
127128
matmul_16m_lt4nRest_k.test.cpp
128-
matmul_16mRest_4nRest_k.n1.test.cpp
129-
matmul_16mRest_4nRest_k.n2.test.cpp
130-
matmul_16mRest_4nRest_k.n3.test.cpp
129+
matmul_16mRest_4nRest_k.test.cpp
131130
)
132131

133132
set(TEST_ARM_INSTRUCTION_FILES

src/main/kernels/matmul_16mRest_4nRest_k.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ void mini_jit::kernels::matmul_16mRest_4nRest_k(mini_jit::Kernel &kernel, const
1919
release_assert(n_loop_rest <= 3, "Cannot create a matrix with a rest of n larger than 3!");
2020

2121
// Idea: Divide the matrix into sub-matrices, which are calculated in the following order.
22-
// =====================================================
23-
// | | |
24-
// | | 2. matmul_16m_lt4nRest_k |
25-
// | 1. matmul_16mRest_4n_k | |
26-
// | |--------------------------|
27-
// | | 3. Rest of m and n |
28-
// =====================================================
22+
// N dimension
23+
// ←---------------------------------------------------→
24+
// ===================================================== ↑
25+
// | | | |
26+
// | | 2. matmul_16m_lt4nRest_k | |
27+
// | 1. matmul_16mRest_4n_k | | | M dimension
28+
// | |--------------------------| |
29+
// | | 3. Rest of m and n | |
30+
// ===================================================== ↓
2931

3032
kernel.add({
3133
// /**

src/test/BaseGeneration.test.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class GenerationTest
5959
void verify_matmul(const float *__restrict__ expected, const float *__restrict__ result, uint32_t size);
6060

6161
public:
62+
GenerationTest() = delete;
6263
GenerationTest(uint32_t M, uint32_t N, uint32_t K);
6364
GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize);
6465
~GenerationTest();

src/test/kernels/matmul.test.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#include "matmul.test.h"
2+
void GemmMxNxKxBatchTestFixture::_RunTest(const uint32_t lda, const uint32_t ldb, const uint32_t ldc, const uint32_t batch_stride_a,
3+
const uint32_t batch_stride_b)
4+
{
5+
if (native_kernel.get_size() <= 0)
6+
{
7+
INFO("The kernel should contain instructions before the test is executed.");
8+
REQUIRE(native_kernel.get_size() > 0);
9+
}
10+
11+
// Generate executable kernel
12+
native_kernel.set_kernel();
13+
mini_jit::Brgemm::kernel_t kernel = reinterpret_cast<mini_jit::Brgemm::kernel_t>(
14+
const_cast<void *>(native_kernel.get_kernel())); // Properly cast from const void* to kernel_t
15+
16+
// Run matmuls
17+
kernel(matrix_a, matrix_b, matrix_c, lda, ldb, ldc, batch_stride_a, batch_stride_b);
18+
naive_matmul_M_N_K_Batch(matrix_a, matrix_b, matrix_c_verify, lda, ldb, ldc, batch_stride_a, batch_stride_b);
19+
20+
verify_matmul(matrix_c_verify, matrix_c, M * N);
21+
};
22+
23+
/**
 * @brief Fills the given matrix with pseudo-random values.
 *
 * Every element is the quotient of two std::rand() draws.
 *
 * The generator is seeded only on the first call. Re-seeding with the current time on every
 * call would restart the same sequence for all matrices that are filled within the same
 * second, so the a, b and c matrices would begin with identical values.
 *
 * A denominator of zero is redrawn, so no element can become inf or NaN.
 *
 * @param matrix The matrix to fill.
 * @param size The total size of the matrix.
 */
void GemmMxNxKxBatchTestFixture::fill_random_matrix(float *matrix, uint32_t size)
{
  // Function-local static: the initializer (and thereby std::srand) runs exactly once.
  static const bool seeded = []
  {
    std::srand(static_cast<unsigned int>(std::time(nullptr)));
    return true;
  }();
  static_cast<void>(seeded);

  for (size_t i = 0; i < size; i++)
  {
    const int numerator = std::rand();

    // std::rand() may return 0; dividing by it would produce inf (or NaN for 0 / 0).
    int denominator = std::rand();
    while (denominator == 0)
    {
      denominator = std::rand();
    }

    matrix[i] = static_cast<float>(numerator) / static_cast<float>(denominator);
  }
}
31+
32+
/**
 * @brief Fills the given matrix with counting values starting from 0.
 *
 * Element at position p receives the value p converted to float.
 *
 * @param matrix The matrix to fill.
 * @param size The total size of the matrix.
 */
void GemmMxNxKxBatchTestFixture::fill_counting_matrix(float *matrix, uint32_t size)
{
  size_t position = 0;
  while (position < size)
  {
    matrix[position] = static_cast<float>(position);
    ++position;
  }
}
39+
40+
void GemmMxNxKxBatchTestFixture::naive_matmul_M_N_K_Batch(const float *__restrict__ a, const float *__restrict__ b, float *__restrict__ c,
41+
int64_t lda, int64_t ldb, int64_t ldc, int64_t batch_stride_a,
42+
int64_t batch_stride_b)
43+
{
44+
for (size_t iB = 0; iB < BatchSize; iB++)
45+
{
46+
for (size_t iM = 0; iM < M; iM++)
47+
{
48+
for (size_t iN = 0; iN < N; iN++)
49+
{
50+
for (size_t iK = 0; iK < K; ++iK)
51+
{
52+
c[iM + iN * ldc] += a[iM + iK * lda + iB * batch_stride_a] * b[iK + iN * ldb + iB * batch_stride_b];
53+
}
54+
}
55+
}
56+
}
57+
}
58+
59+
void GemmMxNxKxBatchTestFixture::verify_matmul(const float *__restrict__ expected, const float *__restrict__ result, uint32_t size)
60+
{
61+
for (size_t i = 0; i < size; i++)
62+
{
63+
CAPTURE(i, result[i], expected[i]);
64+
REQUIRE_THAT(result[i], Catch::Matchers::WithinRel(expected[i]));
65+
}
66+
}
67+
68+
/**
 * @brief Creates the fixture and allocates the a, b, c and verification matrices.
 *
 * The a matrix holds M * K * BatchSize elements, the b matrix K * N * BatchSize elements and
 * both c matrices M * N elements. The element counts are computed as size_t so that large
 * dimensions cannot wrap around in 32-bit arithmetic, and every buffer is zero-initialized
 * so that a fixture that is run before SetUp() does not read indeterminate values.
 *
 * @param M The m dimension of the matmul.
 * @param N The n dimension of the matmul.
 * @param K The k dimension of the matmul.
 * @param BatchSize The number of a and b matrices.
 */
GemmMxNxKxBatchTestFixture::GemmMxNxKxBatchTestFixture(uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize)
    : M(M), N(N), K(K), BatchSize(BatchSize)
{
  // Widen the first factor so the whole product is evaluated in size_t.
  const size_t elements_a = static_cast<size_t>(M) * K * BatchSize;
  const size_t elements_b = static_cast<size_t>(K) * N * BatchSize;
  const size_t elements_c = static_cast<size_t>(M) * N;

  // The trailing () value-initializes the arrays, i.e. all floats start at 0.
  matrix_a = new float[elements_a]();
  matrix_b = new float[elements_b]();
  matrix_c = new float[elements_c]();
  matrix_c_verify = new float[elements_c]();
}
77+
78+
/**
 * @brief Releases the four matrix buffers allocated by the constructor.
 */
GemmMxNxKxBatchTestFixture::~GemmMxNxKxBatchTestFixture()
{
  // Same release order as the allocation order: a, b, c, c_verify.
  float *const owned_buffers[] = {matrix_a, matrix_b, matrix_c, matrix_c_verify};
  for (float *buffer : owned_buffers)
  {
    delete[] buffer;
  }
}
85+
86+
/**
 * @brief Set up the test fixture object.
 *
 * Fills the a, b and c matrices according to the infill type and then copies c into the
 * verification matrix, so that the kernel and the reference start from the same c values.
 * An unknown infill type fails the test.
 *
 * @param fillType Fills the matrices with the given infill type.
 */
void GemmMxNxKxBatchTestFixture::SetUp(TestInfill fillType)
{
  const uint32_t elements_a = M * K * BatchSize;
  const uint32_t elements_b = K * N * BatchSize;
  const uint32_t elements_c = M * N;

  if (fillType == TestInfill::Random)
  {
    fill_random_matrix(matrix_a, elements_a);
    fill_random_matrix(matrix_b, elements_b);
    fill_random_matrix(matrix_c, elements_c);
  }
  else if (fillType == TestInfill::Counting)
  {
    fill_counting_matrix(matrix_a, elements_a);
    fill_counting_matrix(matrix_b, elements_b);
    fill_counting_matrix(matrix_c, elements_c);
  }
  else
  {
    FAIL("Undefined infill type found.");
  }

  // Both result buffers must start from identical contents.
  std::copy(matrix_c, matrix_c + elements_c, matrix_c_verify);
}
107+
108+
/**
 * @brief Creates a fixture for a single (non-batched) matmul by using a batch size of 1.
 *
 * @param M The m dimension of the matmul.
 * @param N The n dimension of the matmul.
 * @param K The k dimension of the matmul.
 */
GemmMxNxKTestFixture::GemmMxNxKTestFixture(uint32_t M, uint32_t N, uint32_t K)
    : GemmMxNxKxBatchTestFixture(M, N, K, /*BatchSize=*/1)
{
}
111+
112+
/**
 * @brief Nothing to release here; the base class destructor frees the matrices.
 */
GemmMxNxKTestFixture::~GemmMxNxKTestFixture() = default;

src/test/kernels/matmul.test.h

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -108,46 +108,71 @@ enum class TestInfill
108108
Counting,
109109
};
110110

111-
template <uint32_t TMdim, uint32_t TNdim, uint32_t TKdim, uint32_t TBatchDim> class GemmMxNxKxBatchTestFixture
111+
class GemmMxNxKxBatchTestFixture
112112
{
113+
private:
114+
uint32_t M;
115+
uint32_t N;
116+
uint32_t K;
117+
uint32_t BatchSize;
118+
float *matrix_a;
119+
float *matrix_b;
120+
float *matrix_c;
121+
float *matrix_c_verify;
122+
123+
/**
124+
* @brief Fills the given matrix with random values.
125+
*
126+
* @param matrix The matrix to fill.
127+
* @param size The total size of the matrix.
128+
*/
129+
void fill_random_matrix(float *matrix, uint32_t size);
130+
131+
/**
132+
* @brief Fills the given matrix with counting values starting from 0.
133+
*
134+
* @param matrix The matrix to fill.
135+
* @param size The total size of the matrix.
136+
*/
137+
void fill_counting_matrix(float *matrix, uint32_t size);
138+
139+
/**
140+
* @brief Does a naive matmul for verification usage.
141+
*
142+
* @param a The a matrix.
143+
* @param b The b matrix.
144+
* @param c The c matrix.
145+
* @param lda The leading dimension of matrix a.
146+
* @param ldb The leading dimension of matrix b.
147+
* @param ldc The leading dimension of matrix c.
148+
* @param batch_stride_a The batch stride of matrix a.
149+
* @param batch_stride_b The batch stride of matrix b.
150+
*/
151+
void naive_matmul_M_N_K_Batch(const float *__restrict__ a, const float *__restrict__ b, float *__restrict__ c, int64_t lda, int64_t ldb,
152+
int64_t ldc, int64_t batch_stride_a, int64_t batch_stride_b);
153+
154+
/**
155+
* @brief Compares the two matrices by comparing each value.
156+
*
157+
* @param expected The matrix results that are expected.
158+
* @param result The actual matrix values.
159+
* @param size The total size of the matrix.
160+
*/
161+
void verify_matmul(const float *__restrict__ expected, const float *__restrict__ result, uint32_t size);
162+
113163
public:
114-
float matrix_a[TMdim * TKdim * TBatchDim];
115-
float matrix_b[TKdim * TNdim * TBatchDim];
116-
float matrix_c[TMdim * TNdim];
117-
float matrix_c_verify[TMdim * TNdim];
118-
const uint32_t lda = TMdim;
119-
const uint32_t ldb = TKdim;
120-
const uint32_t ldc = TMdim;
121-
const uint32_t batch_stride_a = TMdim * TKdim;
122-
const uint32_t batch_stride_b = TKdim * TNdim;
123164
mini_jit::Kernel native_kernel;
124165

166+
GemmMxNxKxBatchTestFixture() = delete;
167+
GemmMxNxKxBatchTestFixture(uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize);
168+
~GemmMxNxKxBatchTestFixture();
169+
125170
/**
126171
* @brief Set up the test fixture object.
127172
*
128173
* @param fillType Fills the matrices with the given infill type.
129174
*/
130-
void SetUp(TestInfill fillType)
131-
{
132-
switch (fillType)
133-
{
134-
case TestInfill::Random:
135-
fill_random_matrix(matrix_a);
136-
fill_random_matrix(matrix_b);
137-
fill_random_matrix(matrix_c);
138-
break;
139-
case TestInfill::Counting:
140-
fill_counting_matrix(matrix_a);
141-
fill_counting_matrix(matrix_b);
142-
fill_counting_matrix(matrix_c);
143-
break;
144-
default:
145-
FAIL("Undefined infill type found.");
146-
break;
147-
}
148-
149-
copy_matrix(matrix_c, matrix_c_verify);
150-
}
175+
void SetUp(TestInfill fillType);
151176

152177
/**
153178
* @brief Executes the test of a BRGemm with the given input.
@@ -173,36 +198,20 @@ template <uint32_t TMdim, uint32_t TNdim, uint32_t TKdim, uint32_t TBatchDim> cl
173198
* @param br_stride_a: stride between two A matrices (in elements, not bytes).
174199
* @param br_stride_b: stride between two B matrices (in elements, not bytes).
175200
*/
176-
void _RunTest(const uint32_t lda, const uint32_t ldb, const uint32_t ldc, const uint32_t batch_stride_a, const uint32_t batch_stride_b)
177-
{
178-
if (native_kernel.get_size() <= 0)
179-
{
180-
INFO("The kernel should contain instructions before the test is executed.");
181-
REQUIRE(native_kernel.get_size() > 0);
182-
}
183-
184-
// Generate executable kernel
185-
native_kernel.set_kernel();
186-
mini_jit::Brgemm::kernel_t kernel = reinterpret_cast<mini_jit::Brgemm::kernel_t>(
187-
const_cast<void *>(native_kernel.get_kernel())); // Properly cast from const void* to kernel_t
188-
189-
// Run matmuls
190-
kernel(matrix_a, matrix_b, matrix_c, lda, ldb, ldc, batch_stride_a, batch_stride_b);
191-
naive_matmul_M_N_K_Batch<TMdim, TNdim, TKdim, TBatchDim>(matrix_a, matrix_b, matrix_c_verify, lda, ldb, ldc, batch_stride_a,
192-
batch_stride_b);
193-
194-
verify_matmul(matrix_c_verify, matrix_c);
195-
}
201+
void _RunTest(const uint32_t lda, const uint32_t ldb, const uint32_t ldc, const uint32_t batch_stride_a, const uint32_t batch_stride_b);
196202
};
197203

198-
template <uint32_t TMdim, uint32_t TNdim, uint32_t TKdim>
199-
class GemmMxNxKTestFixture : public GemmMxNxKxBatchTestFixture<TMdim, TNdim, TKdim, 1>
204+
class GemmMxNxKTestFixture : public GemmMxNxKxBatchTestFixture
200205
{
201206

202207
void RunTest(const uint32_t lda, const uint32_t ldb, const uint32_t ldc, const uint32_t batch_stride_a,
203208
const uint32_t batch_stride_b) = delete; // delete so not visible in a GemmMxNxKTestFixture object.
204209

205210
public:
211+
GemmMxNxKTestFixture() = delete;
212+
GemmMxNxKTestFixture(uint32_t M, uint32_t N, uint32_t K);
213+
~GemmMxNxKTestFixture();
214+
206215
/**
207216
* @brief Executes the test of a BRGemm with the given input.
208217
*
@@ -212,7 +221,7 @@ class GemmMxNxKTestFixture : public GemmMxNxKxBatchTestFixture<TMdim, TNdim, TKd
212221
*/
213222
void RunTest(const uint32_t lda, const uint32_t ldb, const uint32_t ldc)
214223
{
215-
GemmMxNxKxBatchTestFixture<TMdim, TNdim, TKdim, 1>::_RunTest(lda, ldb, ldc, 0, 0);
224+
GemmMxNxKxBatchTestFixture::_RunTest(lda, ldb, ldc, 0, 0);
216225
}
217226
};
218227

src/test/kernels/matmul_16_6_1.test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55

66
TEST_CASE("Test matmul_16_6_1 jited gemm correctness random data", "[jit][correctness][gemm]")
77
{
8-
GemmMxNxKTestFixture<16, 6, 1> gemmTest;
8+
GemmMxNxKTestFixture gemmTest(16, 6, 1);
99
gemmTest.SetUp(TestInfill::Random);
1010
mini_jit::kernels::matmul_16_6_1(gemmTest.native_kernel);
1111
gemmTest.RunTest(16, 1, 16);
1212
}
1313

1414
TEST_CASE("Test matmul_16_6_1 jited gemm correctness counting data", "[jit][correctness][gemm]")
1515
{
16-
GemmMxNxKTestFixture<16, 6, 1> gemmTest;
16+
GemmMxNxKTestFixture gemmTest(16, 6, 1);
1717
gemmTest.SetUp(TestInfill::Counting);
1818
mini_jit::kernels::matmul_16_6_1(gemmTest.native_kernel);
1919
gemmTest.RunTest(16, 1, 16);

src/test/kernels/matmul_16_6_k.test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,31 @@
55

66
TEST_CASE("Test matmul_16_6_k (M=16, N=6, K=1) jited gemm correctness random data", "[jit][correctness][gemm]")
77
{
8-
GemmMxNxKTestFixture<16, 6, 1> gemmTest;
8+
GemmMxNxKTestFixture gemmTest(16, 6, 1);
99
gemmTest.SetUp(TestInfill::Random);
1010
mini_jit::kernels::matmul_16_6_k(gemmTest.native_kernel, 1);
1111
gemmTest.RunTest(16, 1, 16);
1212
}
1313

1414
TEST_CASE("Test matmul_16_6_k (M=16, N=6, K=1) jited gemm correctness counting data", "[jit][correctness][gemm]")
1515
{
16-
GemmMxNxKTestFixture<16, 6, 1> gemmTest;
16+
GemmMxNxKTestFixture gemmTest(16, 6, 1);
1717
gemmTest.SetUp(TestInfill::Counting);
1818
mini_jit::kernels::matmul_16_6_k(gemmTest.native_kernel, 1);
1919
gemmTest.RunTest(16, 1, 16);
2020
}
2121

2222
TEST_CASE("Test matmul_16_6_k (M=16, N=6, K=128) jited gemm correctness random data", "[jit][correctness][gemm]")
2323
{
24-
GemmMxNxKTestFixture<16, 6, 128> gemmTest;
24+
GemmMxNxKTestFixture gemmTest(16, 6, 128);
2525
gemmTest.SetUp(TestInfill::Random);
2626
mini_jit::kernels::matmul_16_6_k(gemmTest.native_kernel, 128);
2727
gemmTest.RunTest(16, 128, 16);
2828
}
2929

3030
TEST_CASE("Test matmul_16_6_k (M=16, N=6, K=128) jited gemm correctness counting data", "[jit][correctness][gemm]")
3131
{
32-
GemmMxNxKTestFixture<16, 6, 128> gemmTest;
32+
GemmMxNxKTestFixture gemmTest(16, 6, 128);
3333
gemmTest.SetUp(TestInfill::Counting);
3434
mini_jit::kernels::matmul_16_6_k(gemmTest.native_kernel, 128);
3535
gemmTest.RunTest(16, 128, 16);

0 commit comments

Comments
 (0)