Skip to content

Commit 51801e6

Browse files
authored
Refactor xor_basis_walk test for fixed bitset size
1 parent dd87860 commit 51801e6

1 file changed

Lines changed: 12 additions & 13 deletions

File tree

tests/library_checker_aizu_tests/handmade_tests/xor_basis_walk.test.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,42 +3,41 @@
33
#include "../template.hpp"
44
#include "../../../library/contest/random.hpp"
55
#include "../../../library/math/matrix_related/xor_basis_ordered.hpp"
6-
const int B = 18;
7-
vector<bitset<B>> get_all(const vector<bitset<B>>& basis) {
6+
vector<bitset<18>> get_all(const vector<bitset<18>>& basis) {
87
int n = ssize(basis);
9-
vector<bitset<B>> span;
8+
vector<bitset<18>> span;
109
for (int mask = 0; mask < (1 << n); mask++) {
11-
bitset<B> curr_xor;
10+
bitset<18> curr_xor;
1211
assert(curr_xor.none());
1312
for (int bit = 0; bit < n; bit++)
1413
if ((mask >> bit) & 1) curr_xor ^= basis[bit];
1514
span.push_back(curr_xor);
1615
}
17-
ranges::sort(span, {}, [&](const bitset<B>& x) -> long {
16+
ranges::sort(span, {}, [&](const bitset<18>& x) -> long {
1817
return x.to_ulong();
1918
});
2019
return span;
2120
}
2221
int main() {
2322
cin.tie(0)->sync_with_stdio(0);
2423
for (int num_tests = 0; num_tests < 100; num_tests++) {
25-
xor_basis<B> b;
24+
xor_basis<18> b;
2625
int n = rnd(1, 16);
27-
vector<bitset<B>> naive_basis;
26+
vector<bitset<18>> naive_basis;
2827
for (int i = 0; i < n; i++) {
29-
bitset<B> val = rnd(0, (1 << n) - 1);
28+
bitset<18> val = rnd(0, (1 << n) - 1);
3029
if (b.insert(val)) naive_basis.push_back(val);
3130
assert(b.npivot + b.nfree == i + 1);
3231
}
3332
assert(ssize(naive_basis) == b.npivot);
34-
vector<bitset<B>> fast_basis;
35-
for (int i = 0; i < B; i++)
33+
vector<bitset<18>> fast_basis;
34+
for (int i = 0; i < 18; i++)
3635
if (b.basis[i][i]) fast_basis.push_back(b.basis[i]);
37-
vector<bitset<B>> naive_span = get_all(naive_basis);
38-
vector<bitset<B>> fast_span = get_all(fast_basis);
36+
vector<bitset<18>> naive_span = get_all(naive_basis);
37+
vector<bitset<18>> fast_span = get_all(fast_basis);
3938
assert(naive_span == fast_span);
4039
for (int i = 0; i < ssize(naive_span); i++) {
41-
bitset<B> k = i;
40+
bitset<18> k = i;
4241
assert(naive_span[i] == b.walk(k));
4342
}
4443
}

0 commit comments

Comments
 (0)