11#include < gtest/gtest.h>
2+ #include < random>
23#include " SIMDMatrix.h"
34
45static constexpr size_t MATRIX_SIZE_LIMIT = 204 ; // Max. 2MB of heap mem used
5-
66using SIMDMatrix = linear_algebra::SIMDMatrix;
77
8+ static std::random_device dev;
9+ static std::mt19937 mersenneTwister (dev());
10+
11+ static SIMDMatrix naiveMultiplication (const SIMDMatrix& lhs, const SIMDMatrix& rhs)
12+ {
13+ // Assertion is enough here. writing code manually.
14+ // SIMDMatrix is treated as some kind of lib
15+ assert (lhs.getRowCount () == rhs.getColCount ());
16+ SIMDMatrix result (lhs.getRowCount (), rhs.getColCount ());
17+
18+ for (size_t i = 0 ; i < result.getRowCount (); i++)
19+ for (size_t j = 0 ; j < result.getColCount (); j++)
20+ {
21+ float dp = 0 .0f ;
22+
23+ for (size_t k = 0 ; k < lhs.getColCount (); k++)
24+ dp += lhs.get (i, k) * rhs.get (k, j);
25+
26+ result.set (i, j, dp);
27+ }
28+
29+ return result;
30+ }
31+
32+ static SIMDMatrix naiveAddition (const SIMDMatrix& lhs, const SIMDMatrix& rhs)
33+ {
34+ assert (lhs.getColCount () == rhs.getColCount ());
35+ assert (lhs.getRowCount () == rhs.getRowCount ());
36+
37+ SIMDMatrix result (lhs.getRowCount (), lhs.getColCount ());
38+
39+ for (size_t i = 0 ; i < result.getRowCount (); i++)
40+ for (size_t j = 0 ; j < result.getColCount (); j++)
41+ result.set (i, j, lhs.get (i, j) + rhs.get (i, j));
42+
43+ return result;
44+ }
45+
46+ static SIMDMatrix naiveScalarMul (float scalar, const SIMDMatrix& rhs)
47+ {
48+ SIMDMatrix result (rhs.getRowCount (), rhs.getColCount ());
49+
50+ for (size_t i = 0 ; i < result.getRowCount (); i++)
51+ for (size_t j = 0 ; j < result.getColCount (); j++)
52+ {
53+ result.set (i, j, rhs.get (i, j) * scalar);
54+ }
55+
56+ return result;
57+ }
58+
59+ static float genRandFloat (float min, float max)
60+ {
61+ std::uniform_real_distribution<float > dist (min, max);
62+ return dist (mersenneTwister);
63+ }
64+
65+ static SIMDMatrix genRandMatrix (size_t rows, size_t cols, float min, float max)
66+ {
67+ SIMDMatrix mat (rows, cols);
68+ std::uniform_real_distribution<float > dist (min, max);
69+
70+ for (size_t i = 0 ; i < rows; i++)
71+ for (size_t j = 0 ; j < cols; j++)
72+ {
73+ mat.set (i, j, dist (mersenneTwister));
74+ }
75+
76+ return mat;
77+ }
78+
879TEST (SIMDMatrix, IdentityScalarMultiplication)
980{
10- for (size_t size = 1 ; size <= MATRIX_SIZE_LIMIT; size++)
81+ for (size_t size = 2 ; size <= MATRIX_SIZE_LIMIT; size++)
1182 {
1283 SIMDMatrix mat = SIMDMatrix::Identity (size);
1384 mat *= 5 ;
@@ -26,10 +97,10 @@ TEST(SIMDMatrix, IdentityScalarMultiplication)
2697
2798TEST (SIMDMatrix, IdentityAddition)
2899{
29- for (size_t size = 1 ; size <= MATRIX_SIZE_LIMIT; size++)
100+ for (size_t size = 2 ; size <= MATRIX_SIZE_LIMIT; size++)
30101 {
31102 SIMDMatrix mat = SIMDMatrix::Identity (size);
32- mat += SIMDMatrix::Identity (size) * 27 ;
103+ mat += SIMDMatrix::Identity (size) * 27 . 0f ;
33104
34105 for (size_t i = 0 ; i < size; i++)
35106 for (size_t j = 0 ; j < size; j++)
@@ -45,10 +116,10 @@ TEST(SIMDMatrix, IdentityAddition)
45116
46117TEST (SIMDMatrix, IdentityMultiplication)
47118{
48- for (size_t size = 1 ; size <= MATRIX_SIZE_LIMIT; size++)
119+ for (size_t size = 2 ; size <= MATRIX_SIZE_LIMIT; size++)
49120 {
50- SIMDMatrix matA = SIMDMatrix::Identity (size) * 3 ;
51- SIMDMatrix matB = SIMDMatrix::Identity (size) * 9 ;
121+ SIMDMatrix matA = SIMDMatrix::Identity (size) * 3 . 0f ;
122+ SIMDMatrix matB = SIMDMatrix::Identity (size) * 9 . 0f ;
52123
53124 SIMDMatrix res = matA * matB;
54125
@@ -62,4 +133,58 @@ TEST(SIMDMatrix, IdentityMultiplication)
62133 EXPECT_FLOAT_EQ (val, 0 .0f );
63134 }
64135 }
136+ }
137+
138+ TEST (SIMDMatrix, Multiplication)
139+ {
140+ for (size_t i = 2 ; i <= MATRIX_SIZE_LIMIT; i++)
141+ for (size_t j = 2 ; j <= MATRIX_SIZE_LIMIT; j++)
142+ {
143+ SIMDMatrix mat1 = genRandMatrix (i, j, 0 .0f , 10 .0f );
144+ SIMDMatrix mat2 = genRandMatrix (j, i, 0 .0f , 10 .0f );
145+
146+ SIMDMatrix res = mat1 * mat2;
147+ SIMDMatrix resCmp = naiveMultiplication (mat1, mat2);
148+
149+ ASSERT_EQ (res.getRowCount (), resCmp.getColCount ());
150+ ASSERT_EQ (res.getColCount (), resCmp.getColCount ());
151+
152+ for (size_t r = 0 ; r < res.getRowCount (); r++)
153+ for (size_t c = 0 ; c < res.getColCount (); c++)
154+ EXPECT_NEAR (res.get (r, c), resCmp.get (r, c), 5e-3 );
155+ }
156+ }
157+
158+ TEST (SIMDMatrix, Addition)
159+ {
160+ for (size_t i = 0 ; i < MATRIX_SIZE_LIMIT; i++)
161+ for (size_t j = 0 ; j < MATRIX_SIZE_LIMIT; j++)
162+ {
163+ SIMDMatrix mat1 = genRandMatrix (i, j, 0 .0f , 10 .0f );
164+ SIMDMatrix mat2 = genRandMatrix (i, j, 0 .0f , 10 .0f );
165+
166+ auto res = mat1 + mat2;
167+ auto resCmp = naiveAddition (mat1, mat2);
168+
169+ for (size_t r = 0 ; r < res.getRowCount (); r++)
170+ for (size_t c = 0 ; c < res.getColCount (); c++)
171+ EXPECT_NEAR (res.get (r, c), resCmp.get (r, c), 5e-3 );
172+ }
173+ }
174+
175+ TEST (SIMDMatrix, ScalarMultiplication)
176+ {
177+ for (size_t i = 0 ; i < MATRIX_SIZE_LIMIT; i++)
178+ for (size_t j = 0 ; j < MATRIX_SIZE_LIMIT; j++)
179+ {
180+ float scalar = genRandFloat (0 .0f , 10 .0f );
181+ SIMDMatrix mat = genRandMatrix (i, j, 0 .0f , 10 .0f );
182+
183+ auto res = scalar * mat;
184+ auto resCmp = naiveScalarMul (scalar, mat);
185+
186+ for (size_t r = 0 ; r < res.getRowCount (); r++)
187+ for (size_t c = 0 ; c < res.getColCount (); c++)
188+ EXPECT_NEAR (res.get (r, c), resCmp.get (r, c), 5e-3 );
189+ }
65190}
0 commit comments