33#include " ../template.hpp"
44#include " ../edge_cd_asserts.hpp"
55#include " ../../../library/trees/edge_cd.hpp"
6+ #include " ../../../library/math/mod_int.hpp"
67int main () {
78 cin.tie (0 )->sync_with_stdio (0 );
89 int n;
910 cin >> n;
1011 vector<int > a (n);
11- vector<int64_t > res (n);
12+ vector<mint > res (n);
1213 for (int i = 0 ; i < n; i++) {
1314 cin >> a[i];
1415 res[i] = a[i];
1516 }
16- vector<vector <int >> adj (n);
17- vector<int > b (n - 1 ), c (n - 1 );
17+ vector<basic_string <int >> adj (n);
18+ vector<mint > b (n - 1 ), c (n - 1 );
1819 vector<pair<int , int >> par (n, {-1 , -1 });
19- const int mod = 998'244'353 ;
2020 {
2121 vector<vector<pair<int , int >>> adj_with_id (n);
2222 for (int i = 0 ; i < n - 1 ; i++) {
2323 int u, v;
24- cin >> u >> v >> b[i] >> c[i];
24+ cin >> u >> v >> b[i]. x >> c[i]. x ;
2525 adj[u].push_back (v);
2626 adj[v].push_back (u);
2727 adj_with_id[u].emplace_back (v, i);
2828 adj_with_id[v].emplace_back (u, i);
29- res[u] += 1LL * b[i] * a[v] + c[i];
30- res[u] %= mod;
31- res[v] += 1LL * b[i] * a[u] + c[i];
32- res[v] %= mod;
29+ res[u] = res[u] + b[i] * a[v] + c[i];
30+ res[v] = res[v] + b[i] * a[u] + c[i];
3331 }
3432 auto dfs = [&](auto && self, int u) -> void {
3533 for (auto [v, e_id] : adj_with_id[u])
@@ -44,35 +42,34 @@ int main() {
4442 assert (u_low ^ v_low);
4543 return u_low ? par[u].second : par[v].second ;
4644 };
47- { edge_cd (adj, edge_cd_asserts); }
45+ // { edge_cd(adj, edge_cd_asserts); }
4846 edge_cd (adj,
49- [&](const vector<vector <int >>& cd_adj, int cent,
47+ [&](const vector<basic_string <int >>& cd_adj, int cent,
5048 int split) -> void {
51- array<vector<array<int64_t , 3 >>, 2 > all_backwards;
52- array<int64_t , 2 > sum_forward = {0 , 0 },
53- cnt_nodes = {0 , 0 };
49+ array<vector<array<mint , 3 >>, 2 > all_backwards;
50+ array<mint , 2 > sum_forward = {0 , 0 };
51+ array< int , 2 > cnt_nodes = {0 , 0 };
5452 auto dfs = [&](auto && self, int u, int p,
55- array<int64_t , 2 > forwards,
56- array<int64_t , 2 > backwards,
53+ array<mint , 2 > forwards,
54+ array<mint , 2 > backwards,
5755 int side) -> void {
5856 all_backwards[side].push_back (
5957 {u, backwards[0 ], backwards[1 ]});
60- sum_forward[side] +=
58+ sum_forward[side] = sum_forward[side] +
6159 forwards[0 ] * a[u] + forwards[1 ];
62- sum_forward[side] %= mod;
6360 cnt_nodes[side]++;
6461 for (int v : cd_adj[u]) {
6562 if (v == p) continue ;
6663 int e_id = edge_id (u, v);
6764 // f(x) = ax+b
6865 // g(x) = cx+d
6966 // f(g(x)) = a(cx+d)+b = acx+ad+b
70- array<int64_t , 2 > curr_forw = {
71- forwards[0 ] * b[e_id] % mod ,
72- ( forwards[0 ] * c[e_id] + forwards[1 ]) % mod };
73- array<int64_t , 2 > curr_backw = {
74- b[e_id ] * backwards[ 0 ] % mod ,
75- (b[e_id ] * backwards[ 1 ] + c[e_id]) % mod };
67+ array<mint , 2 > curr_forw = {
68+ forwards[0 ] * b[e_id],
69+ forwards[0 ] * c[e_id] + forwards[1 ]};
70+ array<mint , 2 > curr_backw = {
71+ backwards[ 0 ] * b[e_id] ,
72+ backwards[ 1 ] * b[e_id ] + c[e_id]};
7673 self (self, v, u, curr_forw, curr_backw, side);
7774 }
7875 };
@@ -84,13 +81,13 @@ int main() {
8481 for (int side = 0 ; side < 2 ; side++) {
8582 for (
8683 auto [u, curr_b, curr_c] : all_backwards[side]) {
87- res[u] += curr_b * sum_forward[!side ] +
88- cnt_nodes [!side] * curr_c;
89- res[u] %= mod ;
84+ res[u. x ] = res[u. x ] +
85+ curr_b * sum_forward [!side] +
86+ curr_c * cnt_nodes[!side] ;
9087 }
9188 }
9289 });
93- for (int i = 0 ; i < n; i++) cout << res[i] << ' ' ;
90+ for (int i = 0 ; i < n; i++) cout << res[i]. x << ' ' ;
9491 cout << ' \n ' ;
9592 return 0 ;
9693}
0 commit comments