Skip to content

Commit a3ec772

Browse files
committed
modernize edge CD
1 parent b4f4a61 commit a3ec772

2 files changed

Lines changed: 29 additions & 33 deletions

File tree

library/trees/edge_cd.hpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,39 +13,36 @@
1313
//! handle single-edge-paths separately
1414
//! @time O(n logφ n)
1515
//! @space O(n)
16-
template<class F, class G> struct edge_cd {
17-
vector<G> adj;
18-
F f;
19-
vi siz;
20-
edge_cd(const vector<G>& adj, F f):
21-
adj(adj), f(f), siz(sz(adj)) {
22-
dfs(0, sz(adj) - 1);
23-
}
24-
int find_cent(int u, int p, int m) {
16+
template<class G>
17+
void edge_cd(vector<G>& adj, const auto& f) {
18+
vi siz(sz(adj));
19+
auto find_cent = [&](auto&& self, int u, int p,
20+
int m) -> int {
2521
siz[u] = 1;
2622
for (int v : adj[u])
2723
if (v != p) {
28-
int cent = find_cent(v, u, m);
24+
int cent = self(self, v, u, m);
2925
if (cent != -1) return cent;
3026
siz[u] += siz[v];
3127
}
3228
return 2 * siz[u] > m
3329
? p >= 0 && (siz[p] = m + 1 - siz[u]),
3430
u : -1;
35-
}
36-
void dfs(int u, int m) {
31+
};
32+
auto dfs = [&](auto&& self, int u, int m) -> void {
3733
if (m < 2) return;
38-
u = find_cent(u, -1, m);
34+
u = find_cent(find_cent, u, -1, m);
3935
int sum = 0;
4036
auto it = partition(all(adj[u]), [&](int v) {
4137
ll x = sum + siz[v];
4238
return x * x < m * (m - x) ? sum += siz[v], 1 : 0;
4339
});
44-
f(adj, u, it - begin(adj[u]));
40+
f(u, it - begin(adj[u]));
4541
G oth(it, end(adj[u]));
4642
adj[u].erase(it, end(adj[u]));
47-
dfs(u, sum);
43+
self(self, u, sum);
4844
swap(adj[u], oth);
49-
dfs(u, m - sum);
50-
}
45+
self(self, u, m - sum);
46+
};
47+
dfs(dfs, 0, sz(adj) - 1);
5148
};

tests/library_checker_aizu_tests/trees/count_paths_per_length.test.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,23 @@
99
//! @time O(n * logφ(n) * log2(n))
1010
//! @space this function allocates/returns various vectors
1111
//! which are each O(n)
12-
vector<ll> count_paths_per_length(const vector<vi>& adj) {
12+
vector<ll> count_paths_per_length(vector<vi>& adj) {
1313
vector<ll> num_paths(sz(adj));
1414
if (sz(adj) >= 2) num_paths[1] = sz(adj) - 1;
15-
edge_cd(adj,
16-
[&](const vector<vi>& cd_adj, int cent, int split) {
17-
vector<vector<double>> cnt(2, vector<double>(1));
18-
auto dfs = [&](auto&& self, int u, int p, int d,
19-
int side) -> void {
20-
if (sz(cnt[side]) == d) cnt[side].push_back(0.0);
21-
cnt[side][d]++;
22-
for (int c : cd_adj[u])
23-
if (c != p) self(self, c, u, 1 + d, side);
24-
};
25-
rep(i, 0, sz(cd_adj[cent]))
26-
dfs(dfs, cd_adj[cent][i], cent, 1, i < split);
27-
vector<double> prod = conv(cnt[0], cnt[1]);
28-
rep(i, 0, sz(prod)) num_paths[i] += llround(prod[i]);
29-
});
15+
edge_cd(adj, [&](int cent, int split) {
16+
vector<vector<double>> cnt(2, vector<double>(1));
17+
auto dfs = [&](auto&& self, int u, int p, int d,
18+
int side) -> void {
19+
if (sz(cnt[side]) == d) cnt[side].push_back(0.0);
20+
cnt[side][d]++;
21+
for (int c : adj[u])
22+
if (c != p) self(self, c, u, 1 + d, side);
23+
};
24+
rep(i, 0, sz(adj[cent]))
25+
dfs(dfs, adj[cent][i], cent, 1, i < split);
26+
vector<double> prod = conv(cnt[0], cnt[1]);
27+
rep(i, 0, sz(prod)) num_paths[i] += llround(prod[i]);
28+
});
3029
return num_paths;
3130
}
3231
int main() {

0 commit comments

Comments
 (0)