Skip to content

Commit 502e483

Browse files
committed
updates
1 parent 6301d13 commit 502e483

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

library/trees/edge_cd.hpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,37 +25,34 @@ template <class F, class G>
2525
struct edge_cd {
2626
vector<G> adj;
2727
F f;
28-
vi sub_sz;
29-
edge_cd(const vector<G>& adj, F f) : adj(adj), f(f), sub_sz(sz(adj)) {
30-
dfs(0, sz(adj));
28+
vi siz;
29+
edge_cd(const vector<G>& adj, F f) : adj(adj), f(f), siz(sz(adj)) {
30+
dfs(0, sz(adj) - 1);
3131
}
32-
int find_cent(int v, int p, int siz) {
33-
sub_sz[v] = 1;
32+
int find_cent(int v, int p, int m) {
33+
siz[v] = 1;
3434
for (int u : adj[v])
3535
if (u != p) {
36-
int cent = find_cent(u, v, siz);
36+
int cent = find_cent(u, v, m);
3737
if (cent != -1) return cent;
38-
sub_sz[v] += sub_sz[u];
38+
siz[v] += siz[u];
3939
}
4040
if (p == -1) return v;
41-
return 2 * sub_sz[v] >= siz ? sub_sz[p] = siz - sub_sz[v], v : -1;
41+
return 2 * siz[v] > m ? siz[p] = m + 1 - siz[v], v : -1;
4242
}
43-
void dfs(int v, int siz) {
44-
if (siz <= 2) return;
45-
v = find_cent(v, -1, siz);
43+
void dfs(int v, int m) {
44+
if (m < 2) return;
45+
v = find_cent(v, -1, m);
4646
int sum = 0;
4747
auto it = partition(all(adj[v]), [&](int u) {
48-
ll b = sum + sub_sz[u];
49-
ll a = siz - 1 - b;
50-
bool ret = (b * b <= a * (a + b));
51-
if (ret) sum += sub_sz[u];
52-
return ret;
48+
ll x = sum + siz[u];
49+
return x * x < m * (m - x) ? sum += siz[u], 1 : 0;
5350
});
5451
f(adj, v, it - begin(adj[v]));
5552
G oth(it, end(adj[v]));
5653
adj[v].erase(it, end(adj[v]));
57-
dfs(v, sum + 1);
54+
dfs(v, sum);
5855
swap(adj[v], oth);
59-
dfs(v, siz - sum);
56+
dfs(v, m - sum);
6057
}
6158
};

tests/library_checker_aizu_tests/edge_cd_asserts.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,7 @@ void edge_cd_asserts(const vector<vi>& adj, int cent, int split) {
4242
assert(b > 0);
4343
if (a > b) swap(a, b);
4444
assert(is_balanced(a, b));
45+
assert(!is_balanced(a, cnts[0] + b));
46+
assert(!is_balanced(b, cnts[0] + a));
4547
}
4648
}

0 commit comments

Comments
 (0)