Skip to content

Commit 470a1fd

Browse files
committed
trying this
1 parent fe6e7d3 commit 470a1fd

File tree

1 file changed

+25
-28
lines changed

1 file changed

+25
-28
lines changed

tests/library_checker_aizu_tests/trees/edge_cd_reroot_dp.test.cpp

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,31 @@
33
#include "../template.hpp"
44
#include "../edge_cd_asserts.hpp"
55
#include "../../../library/trees/edge_cd.hpp"
6+
#include "../../../library/math/mod_int.hpp"
67
int main() {
78
cin.tie(0)->sync_with_stdio(0);
89
int n;
910
cin >> n;
1011
vector<int> a(n);
11-
vector<int64_t> res(n);
12+
vector<mint> res(n);
1213
for (int i = 0; i < n; i++) {
1314
cin >> a[i];
1415
res[i] = a[i];
1516
}
16-
vector<vector<int>> adj(n);
17-
vector<int> b(n - 1), c(n - 1);
17+
vector<basic_string<int>> adj(n);
18+
vector<mint> b(n - 1), c(n - 1);
1819
vector<pair<int, int>> par(n, {-1, -1});
19-
const int mod = 998'244'353;
2020
{
2121
vector<vector<pair<int, int>>> adj_with_id(n);
2222
for (int i = 0; i < n - 1; i++) {
2323
int u, v;
24-
cin >> u >> v >> b[i] >> c[i];
24+
cin >> u >> v >> b[i].x >> c[i].x;
2525
adj[u].push_back(v);
2626
adj[v].push_back(u);
2727
adj_with_id[u].emplace_back(v, i);
2828
adj_with_id[v].emplace_back(u, i);
29-
res[u] += 1LL * b[i] * a[v] + c[i];
30-
res[u] %= mod;
31-
res[v] += 1LL * b[i] * a[u] + c[i];
32-
res[v] %= mod;
29+
res[u] = res[u] + b[i] * a[v] + c[i];
30+
res[v] = res[v] + b[i] * a[u] + c[i];
3331
}
3432
auto dfs = [&](auto&& self, int u) -> void {
3533
for (auto [v, e_id] : adj_with_id[u])
@@ -44,35 +42,34 @@ int main() {
4442
assert(u_low ^ v_low);
4543
return u_low ? par[u].second : par[v].second;
4644
};
47-
{ edge_cd(adj, edge_cd_asserts); }
45+
//{ edge_cd(adj, edge_cd_asserts); }
4846
edge_cd(adj,
49-
[&](const vector<vector<int>>& cd_adj, int cent,
47+
[&](const vector<basic_string<int>>& cd_adj, int cent,
5048
int split) -> void {
51-
array<vector<array<int64_t, 3>>, 2> all_backwards;
52-
array<int64_t, 2> sum_forward = {0, 0},
53-
cnt_nodes = {0, 0};
49+
array<vector<array<mint, 3>>, 2> all_backwards;
50+
array<mint, 2> sum_forward = {0, 0};
51+
array<int, 2> cnt_nodes = {0, 0};
5452
auto dfs = [&](auto&& self, int u, int p,
55-
array<int64_t, 2> forwards,
56-
array<int64_t, 2> backwards,
53+
array<mint, 2> forwards,
54+
array<mint, 2> backwards,
5755
int side) -> void {
5856
all_backwards[side].push_back(
5957
{u, backwards[0], backwards[1]});
60-
sum_forward[side] +=
58+
sum_forward[side] = sum_forward[side] +
6159
forwards[0] * a[u] + forwards[1];
62-
sum_forward[side] %= mod;
6360
cnt_nodes[side]++;
6461
for (int v : cd_adj[u]) {
6562
if (v == p) continue;
6663
int e_id = edge_id(u, v);
6764
// f(x) = ax+b
6865
// g(x) = cx+d
6966
// f(g(x)) = a(cx+d)+b = acx+ad+b
70-
array<int64_t, 2> curr_forw = {
71-
forwards[0] * b[e_id] % mod,
72-
(forwards[0] * c[e_id] + forwards[1]) % mod};
73-
array<int64_t, 2> curr_backw = {
74-
b[e_id] * backwards[0] % mod,
75-
(b[e_id] * backwards[1] + c[e_id]) % mod};
67+
array<mint, 2> curr_forw = {
68+
forwards[0] * b[e_id],
69+
forwards[0] * c[e_id] + forwards[1]};
70+
array<mint, 2> curr_backw = {
71+
backwards[0] * b[e_id],
72+
backwards[1] * b[e_id] + c[e_id]};
7673
self(self, v, u, curr_forw, curr_backw, side);
7774
}
7875
};
@@ -84,13 +81,13 @@ int main() {
8481
for (int side = 0; side < 2; side++) {
8582
for (
8683
auto [u, curr_b, curr_c] : all_backwards[side]) {
87-
res[u] += curr_b * sum_forward[!side] +
88-
cnt_nodes[!side] * curr_c;
89-
res[u] %= mod;
84+
res[u.x] = res[u.x] +
85+
curr_b * sum_forward[!side] +
86+
curr_c * cnt_nodes[!side];
9087
}
9188
}
9289
});
93-
for (int i = 0; i < n; i++) cout << res[i] << ' ';
90+
for (int i = 0; i < n; i++) cout << res[i].x << ' ';
9491
cout << '\n';
9592
return 0;
9693
}

0 commit comments

Comments
 (0)