|
3 | 3 | #include "../template.hpp" |
4 | 4 | #include "../../../library/contest/random.hpp" |
5 | 5 | #include "../../../library/math/matrix_related/xor_basis_ordered.hpp" |
6 | | -const int B = 18; |
7 | | -vector<bitset<B>> get_all(const vector<bitset<B>>& basis) { |
| 6 | +vector<bitset<18>> get_all(const vector<bitset<18>>& basis) { |
8 | 7 | int n = ssize(basis); |
9 | | - vector<bitset<B>> span; |
| 8 | + vector<bitset<18>> span; |
10 | 9 | for (int mask = 0; mask < (1 << n); mask++) { |
11 | | - bitset<B> curr_xor; |
| 10 | + bitset<18> curr_xor; |
12 | 11 | assert(curr_xor.none()); |
13 | 12 | for (int bit = 0; bit < n; bit++) |
14 | 13 | if ((mask >> bit) & 1) curr_xor ^= basis[bit]; |
15 | 14 | span.push_back(curr_xor); |
16 | 15 | } |
17 | | - ranges::sort(span, {}, [&](const bitset<B>& x) -> long { |
| 16 | + ranges::sort(span, {}, [&](const bitset<18>& x) -> long { |
18 | 17 | return x.to_ulong(); |
19 | 18 | }); |
20 | 19 | return span; |
21 | 20 | } |
22 | 21 | int main() { |
23 | 22 | cin.tie(0)->sync_with_stdio(0); |
24 | 23 | for (int num_tests = 0; num_tests < 100; num_tests++) { |
25 | | - xor_basis<B> b; |
| 24 | + xor_basis<18> b; |
26 | 25 | int n = rnd(1, 16); |
27 | | - vector<bitset<B>> naive_basis; |
| 26 | + vector<bitset<18>> naive_basis; |
28 | 27 | for (int i = 0; i < n; i++) { |
29 | | - bitset<B> val = rnd(0, (1 << n) - 1); |
| 28 | + bitset<18> val = rnd(0, (1 << n) - 1); |
30 | 29 | if (b.insert(val)) naive_basis.push_back(val); |
31 | 30 | assert(b.npivot + b.nfree == i + 1); |
32 | 31 | } |
33 | 32 | assert(ssize(naive_basis) == b.npivot); |
34 | | - vector<bitset<B>> fast_basis; |
35 | | - for (int i = 0; i < B; i++) |
| 33 | + vector<bitset<18>> fast_basis; |
| 34 | + for (int i = 0; i < 18; i++) |
36 | 35 | if (b.basis[i][i]) fast_basis.push_back(b.basis[i]); |
37 | | - vector<bitset<B>> naive_span = get_all(naive_basis); |
38 | | - vector<bitset<B>> fast_span = get_all(fast_basis); |
| 36 | + vector<bitset<18>> naive_span = get_all(naive_basis); |
| 37 | + vector<bitset<18>> fast_span = get_all(fast_basis); |
39 | 38 | assert(naive_span == fast_span); |
40 | 39 | for (int i = 0; i < ssize(naive_span); i++) { |
41 | | - bitset<B> k = i; |
| 40 | + bitset<18> k = i; |
42 | 41 | assert(naive_span[i] == b.walk(k)); |
43 | 42 | } |
44 | 43 | } |
|
0 commit comments