Skip to content

Commit 44323aa

Browse files
committed
complete tests
1 parent 525a005 commit 44323aa

1 file changed

Lines changed: 132 additions & 7 deletions

File tree

test/test_matrix.cpp

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,84 @@
11
#include <gtest/gtest.h>
2+
#include <random>
23
#include "SIMDMatrix.h"
34

45
static constexpr size_t MATRIX_SIZE_LIMIT = 204; // Max. 2MB of heap mem used
5-
66
using SIMDMatrix = linear_algebra::SIMDMatrix;
77

8+
static std::random_device dev;
9+
static std::mt19937 mersenneTwister(dev());
10+
11+
static SIMDMatrix naiveMultiplication(const SIMDMatrix& lhs, const SIMDMatrix& rhs)
12+
{
13+
// Assertion is enough here. writing code manually.
14+
// SIMDMatrix is treated as some kind of lib
15+
assert(lhs.getRowCount() == rhs.getColCount());
16+
SIMDMatrix result(lhs.getRowCount(), rhs.getColCount());
17+
18+
for (size_t i = 0; i < result.getRowCount(); i++)
19+
for (size_t j = 0; j < result.getColCount(); j++)
20+
{
21+
float dp = 0.0f;
22+
23+
for (size_t k = 0; k < lhs.getColCount(); k++)
24+
dp += lhs.get(i, k) * rhs.get(k, j);
25+
26+
result.set(i, j, dp);
27+
}
28+
29+
return result;
30+
}
31+
32+
static SIMDMatrix naiveAddition(const SIMDMatrix& lhs, const SIMDMatrix& rhs)
33+
{
34+
assert(lhs.getColCount() == rhs.getColCount());
35+
assert(lhs.getRowCount() == rhs.getRowCount());
36+
37+
SIMDMatrix result(lhs.getRowCount(), lhs.getColCount());
38+
39+
for (size_t i = 0; i < result.getRowCount(); i++)
40+
for (size_t j = 0; j < result.getColCount(); j++)
41+
result.set(i, j, lhs.get(i, j) + rhs.get(i, j));
42+
43+
return result;
44+
}
45+
46+
static SIMDMatrix naiveScalarMul(float scalar, const SIMDMatrix& rhs)
47+
{
48+
SIMDMatrix result(rhs.getRowCount(), rhs.getColCount());
49+
50+
for (size_t i = 0; i < result.getRowCount(); i++)
51+
for (size_t j = 0; j < result.getColCount(); j++)
52+
{
53+
result.set(i, j, rhs.get(i, j) * scalar);
54+
}
55+
56+
return result;
57+
}
58+
59+
static float genRandFloat(float min, float max)
60+
{
61+
std::uniform_real_distribution<float> dist(min, max);
62+
return dist(mersenneTwister);
63+
}
64+
65+
static SIMDMatrix genRandMatrix(size_t rows, size_t cols, float min, float max)
66+
{
67+
SIMDMatrix mat(rows, cols);
68+
std::uniform_real_distribution<float> dist(min, max);
69+
70+
for (size_t i = 0; i < rows; i++)
71+
for (size_t j = 0; j < cols; j++)
72+
{
73+
mat.set(i, j, dist(mersenneTwister));
74+
}
75+
76+
return mat;
77+
}
78+
879
TEST(SIMDMatrix, IdentityScalarMultiplication)
980
{
10-
for (size_t size = 1; size <= MATRIX_SIZE_LIMIT; size++)
81+
for (size_t size = 2; size <= MATRIX_SIZE_LIMIT; size++)
1182
{
1283
SIMDMatrix mat = SIMDMatrix::Identity(size);
1384
mat *= 5;
@@ -26,10 +97,10 @@ TEST(SIMDMatrix, IdentityScalarMultiplication)
2697

2798
TEST(SIMDMatrix, IdentityAddition)
2899
{
29-
for (size_t size = 1; size <= MATRIX_SIZE_LIMIT; size++)
100+
for (size_t size = 2; size <= MATRIX_SIZE_LIMIT; size++)
30101
{
31102
SIMDMatrix mat = SIMDMatrix::Identity(size);
32-
mat += SIMDMatrix::Identity(size) * 27;
103+
mat += SIMDMatrix::Identity(size) * 27.0f;
33104

34105
for (size_t i = 0; i < size; i++)
35106
for (size_t j = 0; j < size; j++)
@@ -45,10 +116,10 @@ TEST(SIMDMatrix, IdentityAddition)
45116

46117
TEST(SIMDMatrix, IdentityMultiplication)
47118
{
48-
for (size_t size = 1; size <= MATRIX_SIZE_LIMIT; size++)
119+
for (size_t size = 2; size <= MATRIX_SIZE_LIMIT; size++)
49120
{
50-
SIMDMatrix matA = SIMDMatrix::Identity(size) * 3;
51-
SIMDMatrix matB = SIMDMatrix::Identity(size) * 9;
121+
SIMDMatrix matA = SIMDMatrix::Identity(size) * 3.0f;
122+
SIMDMatrix matB = SIMDMatrix::Identity(size) * 9.0f;
52123

53124
SIMDMatrix res = matA * matB;
54125

@@ -62,4 +133,58 @@ TEST(SIMDMatrix, IdentityMultiplication)
62133
EXPECT_FLOAT_EQ(val, 0.0f);
63134
}
64135
}
136+
}
137+
138+
TEST(SIMDMatrix, Multiplication)
139+
{
140+
for (size_t i = 2; i <= MATRIX_SIZE_LIMIT; i++)
141+
for (size_t j = 2; j <= MATRIX_SIZE_LIMIT; j++)
142+
{
143+
SIMDMatrix mat1 = genRandMatrix(i, j, 0.0f, 10.0f);
144+
SIMDMatrix mat2 = genRandMatrix(j, i, 0.0f, 10.0f);
145+
146+
SIMDMatrix res = mat1 * mat2;
147+
SIMDMatrix resCmp = naiveMultiplication(mat1, mat2);
148+
149+
ASSERT_EQ(res.getRowCount(), resCmp.getColCount());
150+
ASSERT_EQ(res.getColCount(), resCmp.getColCount());
151+
152+
for (size_t r = 0; r < res.getRowCount(); r++)
153+
for (size_t c = 0; c < res.getColCount(); c++)
154+
EXPECT_NEAR(res.get(r, c), resCmp.get(r, c), 5e-3);
155+
}
156+
}
157+
158+
TEST(SIMDMatrix, Addition)
159+
{
160+
for (size_t i = 0; i < MATRIX_SIZE_LIMIT; i++)
161+
for (size_t j = 0; j < MATRIX_SIZE_LIMIT; j++)
162+
{
163+
SIMDMatrix mat1 = genRandMatrix(i, j, 0.0f, 10.0f);
164+
SIMDMatrix mat2 = genRandMatrix(i, j, 0.0f, 10.0f);
165+
166+
auto res = mat1 + mat2;
167+
auto resCmp = naiveAddition(mat1, mat2);
168+
169+
for (size_t r = 0; r < res.getRowCount(); r++)
170+
for (size_t c = 0; c < res.getColCount(); c++)
171+
EXPECT_NEAR(res.get(r, c), resCmp.get(r, c), 5e-3);
172+
}
173+
}
174+
175+
TEST(SIMDMatrix, ScalarMultiplication)
176+
{
177+
for (size_t i = 0; i < MATRIX_SIZE_LIMIT; i++)
178+
for (size_t j = 0; j < MATRIX_SIZE_LIMIT; j++)
179+
{
180+
float scalar = genRandFloat(0.0f, 10.0f);
181+
SIMDMatrix mat = genRandMatrix(i, j, 0.0f, 10.0f);
182+
183+
auto res = scalar * mat;
184+
auto resCmp = naiveScalarMul(scalar, mat);
185+
186+
for (size_t r = 0; r < res.getRowCount(); r++)
187+
for (size_t c = 0; c < res.getColCount(); c++)
188+
EXPECT_NEAR(res.get(r, c), resCmp.get(r, c), 5e-3);
189+
}
65190
}

0 commit comments

Comments
 (0)