@@ -108,46 +108,71 @@ enum class TestInfill
108108 Counting,
109109};
110110
111- template < uint32_t TMdim, uint32_t TNdim, uint32_t TKdim, uint32_t TBatchDim> class GemmMxNxKxBatchTestFixture
111+ class GemmMxNxKxBatchTestFixture
112112{
113+ private:
114+ uint32_t M;
115+ uint32_t N;
116+ uint32_t K;
117+ uint32_t BatchSize;
118+ float *matrix_a;
119+ float *matrix_b;
120+ float *matrix_c;
121+ float *matrix_c_verify;
122+
123+ /* *
124+ * @brief Fills the given matrix with random values.
125+ *
126+ * @param matrix The matrix to fill.
127+ * @param size The total size of the matrix.
128+ */
129+ void fill_random_matrix (float *matrix, uint32_t size);
130+
131+ /* *
132+ * @brief Fills the given matrix with counting values starting from 0.
133+ *
134+ * @param matrix The matrix to fill.
135+ * @param size The total size of the matrix.
136+ */
137+ void fill_counting_matrix (float *matrix, uint32_t size);
138+
139+ /* *
140+ * @brief Does a naive matmul for verification usage.
141+ *
142+ * @param a The a matrix.
143+ * @param b The b matrix.
144+ * @param c The c matrix.
145+ * @param lda The leading dimension of matrix a.
146+ * @param ldb The leading dimension of matrix b.
147+ * @param ldc The leading dimension of matrix c.
148+ * @param batch_stride_a The batch stride of matrix a.
149+ * @param batch_stride_b The batch stride of matrix b.
150+ */
151+ void naive_matmul_M_N_K_Batch (const float *__restrict__ a, const float *__restrict__ b, float *__restrict__ c, int64_t lda, int64_t ldb,
152+ int64_t ldc, int64_t batch_stride_a, int64_t batch_stride_b);
153+
154+ /* *
155+ * @brief Compares the two matrices by comparing each values.
156+ *
157+ * @param expected The matrix results that are expected.
158+ * @param result The actual matrix values.
159+ * @param size The total size of the matrix.
160+ */
161+ void verify_matmul (const float *__restrict__ expected, const float *__restrict__ result, uint32_t size);
162+
113163public:
114- float matrix_a[TMdim * TKdim * TBatchDim];
115- float matrix_b[TKdim * TNdim * TBatchDim];
116- float matrix_c[TMdim * TNdim];
117- float matrix_c_verify[TMdim * TNdim];
118- const uint32_t lda = TMdim;
119- const uint32_t ldb = TKdim;
120- const uint32_t ldc = TMdim;
121- const uint32_t batch_stride_a = TMdim * TKdim;
122- const uint32_t batch_stride_b = TKdim * TNdim;
123164 mini_jit::Kernel native_kernel;
124165
166+ GemmMxNxKxBatchTestFixture () = delete ;
167+ GemmMxNxKxBatchTestFixture (uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize);
168+ ~GemmMxNxKxBatchTestFixture ();
169+
125170 /* *
126171 * @brief Set up the test fixture object.
127172 *
128173 * @param fillType Fills the matrices with the given infill type.
129174 */
130- void SetUp (TestInfill fillType)
131- {
132- switch (fillType)
133- {
134- case TestInfill::Random:
135- fill_random_matrix (matrix_a);
136- fill_random_matrix (matrix_b);
137- fill_random_matrix (matrix_c);
138- break ;
139- case TestInfill::Counting:
140- fill_counting_matrix (matrix_a);
141- fill_counting_matrix (matrix_b);
142- fill_counting_matrix (matrix_c);
143- break ;
144- default :
145- FAIL (" Undefined infill type found." );
146- break ;
147- }
148-
149- copy_matrix (matrix_c, matrix_c_verify);
150- }
175+ void SetUp (TestInfill fillType);
151176
152177 /* *
153178 * @brief Executes the Test von an BRGemm with the given input.
@@ -173,36 +198,20 @@ template <uint32_t TMdim, uint32_t TNdim, uint32_t TKdim, uint32_t TBatchDim> cl
173198 * @param br_stride_a: stride between two A matrices (in elements, not bytes).
174199 * @param br_stride_b: stride between two B matrices (in elements, not bytes).
175200 */
176- void _RunTest (const uint32_t lda, const uint32_t ldb, const uint32_t ldc, const uint32_t batch_stride_a, const uint32_t batch_stride_b)
177- {
178- if (native_kernel.get_size () <= 0 )
179- {
180- INFO (" The kernel should contain instructions before the test is executed." );
181- REQUIRE (native_kernel.get_size () > 0 );
182- }
183-
184- // Generate executable kernel
185- native_kernel.set_kernel ();
186- mini_jit::Brgemm::kernel_t kernel = reinterpret_cast <mini_jit::Brgemm::kernel_t >(
187- const_cast <void *>(native_kernel.get_kernel ())); // Properly cast from const void* to kernel_t
188-
189- // Run matmuls
190- kernel (matrix_a, matrix_b, matrix_c, lda, ldb, ldc, batch_stride_a, batch_stride_b);
191- naive_matmul_M_N_K_Batch<TMdim, TNdim, TKdim, TBatchDim>(matrix_a, matrix_b, matrix_c_verify, lda, ldb, ldc, batch_stride_a,
192- batch_stride_b);
193-
194- verify_matmul (matrix_c_verify, matrix_c);
195- }
201+ void _RunTest (const uint32_t lda, const uint32_t ldb, const uint32_t ldc, const uint32_t batch_stride_a, const uint32_t batch_stride_b);
196202};
197203
198- template <uint32_t TMdim, uint32_t TNdim, uint32_t TKdim>
199- class GemmMxNxKTestFixture : public GemmMxNxKxBatchTestFixture <TMdim, TNdim, TKdim, 1 >
204+ class GemmMxNxKTestFixture : public GemmMxNxKxBatchTestFixture
200205{
201206
202207 void RunTest (const uint32_t lda, const uint32_t ldb, const uint32_t ldc, const uint32_t batch_stride_a,
203208 const uint32_t batch_stride_b) = delete; // delete so not visible in a GemmMxNxKTestFixture object.
204209
205210public:
211+ GemmMxNxKTestFixture () = delete ;
212+ GemmMxNxKTestFixture (uint32_t M, uint32_t N, uint32_t K);
213+ ~GemmMxNxKTestFixture ();
214+
206215 /* *
207216 * @brief Executes the Test von an BRGemm with the given input.
208217 *
@@ -212,7 +221,7 @@ class GemmMxNxKTestFixture : public GemmMxNxKxBatchTestFixture<TMdim, TNdim, TKd
212221 */
213222 void RunTest (const uint32_t lda, const uint32_t ldb, const uint32_t ldc)
214223 {
215- GemmMxNxKxBatchTestFixture<TMdim, TNdim, TKdim, 1 > ::_RunTest (lda, ldb, ldc, 0 , 0 );
224+ GemmMxNxKxBatchTestFixture::_RunTest (lda, ldb, ldc, 0 , 0 );
216225 }
217226};
218227
0 commit comments