Skip to content

Commit f10afed

Browse files
add cent decomp (#38)
* add cent decomp * complete docs * fix docs * increase TL for naive conv * clean up dfs order to a bfs * single core * undo ci changes * add ripped fft * fix clippy * fix doc
1 parent 27dd9c7 commit f10afed

File tree

7 files changed

+342
-0
lines changed

7 files changed

+342
-0
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ path = "examples/data_structures/seg_tree_build_on_array.rs"
4646
name = "trie"
4747
path = "examples/data_structures/trie.rs"
4848

49+
[[example]]
50+
name = "count_paths_per_length"
51+
path = "examples/graphs/count_paths_per_length.rs"
52+
4953
[[example]]
5054
name = "dijk_aizu"
5155
path = "examples/graphs/dijk_aizu.rs"
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// verification-helper: PROBLEM https://judge.yosupo.jp/problem/frequency_table_of_tree_distance
2+
3+
use proconio::input;
4+
use programming_team_code_rust::graphs::count_paths_per_length::count_paths_per_length;
5+
6+
fn main() {
7+
input! {
8+
n: usize
9+
}
10+
11+
let mut adj = vec![vec![]; n];
12+
for _ in 0..n - 1 {
13+
input! {
14+
u: usize,
15+
v: usize
16+
}
17+
adj[u].push(v);
18+
adj[v].push(u);
19+
}
20+
21+
let paths_per_length = count_paths_per_length(&adj);
22+
23+
println!(
24+
"{}",
25+
paths_per_length
26+
.iter()
27+
.skip(1)
28+
.map(|&x| x.to_string())
29+
.collect::<Vec<_>>()
30+
.join(" ")
31+
);
32+
}

src/graphs/cent_decomp.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//! # Centroid Decomposition
2+
3+
fn calc_sz(adj: &[Vec<usize>], u: usize, p: usize, sub_sz: &mut [usize]) {
4+
sub_sz[u] = 1;
5+
for &v in adj[u].iter() {
6+
if v != p {
7+
calc_sz(adj, v, u, sub_sz);
8+
sub_sz[u] += sub_sz[v];
9+
}
10+
}
11+
}
12+
13+
fn dfs(
14+
adj: &mut [Vec<usize>],
15+
mut u: usize,
16+
sub_sz: &mut [usize],
17+
call_dfs: &mut dyn CentDecompDfs,
18+
) {
19+
calc_sz(adj, u, u, sub_sz);
20+
let sz_root = sub_sz[u];
21+
let mut p = u;
22+
loop {
23+
let big_ch = adj[u]
24+
.iter()
25+
.filter(|&&v| v != p)
26+
.find(|&&v| sub_sz[v] * 2 > sz_root);
27+
if let Some(&v) = big_ch {
28+
p = u;
29+
u = v;
30+
} else {
31+
break;
32+
}
33+
}
34+
call_dfs.dfs(adj, u);
35+
for v in adj[u].clone() {
36+
adj[v].retain(|&x| x != u);
37+
dfs(adj, v, sub_sz, call_dfs);
38+
}
39+
}
40+
41+
/// A trait containing the DFS method which is called on each centroid of the tree with the
42+
/// back-edges outside of this centroid removed from `adj`
43+
pub trait CentDecompDfs {
44+
/// The DFS method
45+
fn dfs(&mut self, adj: &[Vec<usize>], cent: usize);
46+
}
47+
48+
/// # Example
49+
/// - see count_paths_per_length.rs
50+
///
51+
/// # Params
52+
/// - `adj`: adjacency list representing an unrooted undirected tree
53+
/// - `call_dfs`: an object implementing `CentDecompDfs` trait
54+
///
55+
/// # Complexity
56+
/// - Time: O(n log n)
57+
/// - Space: O(n)
58+
pub fn cent_decomp(mut adj: Vec<Vec<usize>>, call_dfs: &mut dyn CentDecompDfs) {
59+
let n = adj.len();
60+
let mut sub_sz = vec![0; n];
61+
for s in 0..n {
62+
if sub_sz[s] == 0 {
63+
dfs(&mut adj, s, &mut sub_sz, call_dfs);
64+
}
65+
}
66+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//! # Count the number of paths of each length in a tree
2+
use crate::graphs::cent_decomp::{cent_decomp, CentDecompDfs};
3+
use crate::numbers::fft::fft_multiply;
4+
5+
fn conv(a: &[u64], b: &[u64]) -> Vec<u64> {
6+
let a = a.iter().map(|&x| x as f64).collect::<Vec<_>>();
7+
let b = b.iter().map(|&x| x as f64).collect::<Vec<_>>();
8+
let res_len = a.len() + b.len() - 1;
9+
let mut res = fft_multiply(a, b);
10+
res.resize(res_len, 0.0);
11+
res.iter().map(|&x| x.round() as u64).collect()
12+
}
13+
14+
struct CountPathsPerLength {
15+
num_paths: Vec<u64>,
16+
}
17+
18+
impl CountPathsPerLength {
19+
fn new(n: usize) -> Self {
20+
Self {
21+
num_paths: vec![0; n],
22+
}
23+
}
24+
}
25+
26+
impl CentDecompDfs for CountPathsPerLength {
27+
fn dfs(&mut self, adj: &[Vec<usize>], cent: usize) {
28+
let mut child_depths = vec![vec![]];
29+
for &child in adj[cent].iter() {
30+
let mut my_child_depths = vec![0];
31+
32+
use std::collections::VecDeque;
33+
34+
let mut q = VecDeque::new();
35+
q.push_back((child, cent));
36+
37+
while !q.is_empty() {
38+
my_child_depths.push(q.len() as u64);
39+
40+
let mut new_q = VecDeque::new();
41+
while let Some((u, p)) = q.pop_front() {
42+
for &v in adj[u].iter() {
43+
if v != p {
44+
new_q.push_back((v, u));
45+
}
46+
}
47+
}
48+
49+
q = new_q;
50+
}
51+
52+
child_depths.push(my_child_depths);
53+
}
54+
55+
child_depths.sort_by_key(|v| v.len());
56+
57+
let mut acc = vec![1];
58+
for depth_arr in child_depths {
59+
let res = conv(&acc, &depth_arr);
60+
for (d, &cnt) in res.iter().enumerate() {
61+
self.num_paths[d] += cnt;
62+
}
63+
64+
acc.resize(acc.len().max(depth_arr.len()), 0);
65+
for (d, &cnt) in depth_arr.iter().enumerate() {
66+
acc[d] += cnt;
67+
}
68+
}
69+
}
70+
}
71+
72+
/// # Example
73+
/// ```
74+
/// use programming_team_code_rust::graphs::count_paths_per_length::count_paths_per_length;
75+
/// let adj = vec![
76+
/// vec![1, 2],
77+
/// vec![0, 3, 4],
78+
/// vec![0, 5, 6],
79+
/// vec![1],
80+
/// vec![1],
81+
/// vec![2],
82+
/// vec![2],
83+
/// ];
84+
/// let res = count_paths_per_length(&adj);
85+
/// assert_eq!(res, vec![0, 6, 7, 4, 4, 0, 0]);
86+
/// ```
87+
///
88+
/// # Complexity
89+
/// - Time: O(n log^2 n) with FFT for convolution
90+
/// - Space: O(n)
91+
pub fn count_paths_per_length(adj: &[Vec<usize>]) -> Vec<u64> {
92+
let n = adj.len();
93+
let mut obj = CountPathsPerLength::new(n);
94+
cent_decomp(adj.to_vec(), &mut obj);
95+
obj.num_paths
96+
}

src/graphs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
//! # Graph Algorithms
2+
pub mod cent_decomp;
3+
pub mod count_paths_per_length;
24
mod dfs_order;
35
pub mod dijk;
46
pub mod lca;

src/numbers/fft.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//! Fast Fourier Transform (FFT) implementation ripped from: <https://github.com/bminaiev/rust-contests/blob/main/algo_lib/src/math/fft.rs>
2+
use std::ops::{Add, Mul, MulAssign, Sub};
3+
4+
#[derive(Copy, Clone)]
5+
struct Complex {
6+
real: f64,
7+
imag: f64,
8+
}
9+
10+
impl Complex {
11+
const ZERO: Self = Complex {
12+
real: 0.0,
13+
imag: 0.0,
14+
};
15+
const ONE: Self = Complex {
16+
real: 1.0,
17+
imag: 0.0,
18+
};
19+
}
20+
21+
impl Mul for Complex {
22+
type Output = Complex;
23+
24+
fn mul(self, rhs: Self) -> Self::Output {
25+
Self {
26+
real: self.real * rhs.real - self.imag * rhs.imag,
27+
imag: self.real * rhs.imag + self.imag * rhs.real,
28+
}
29+
}
30+
}
31+
32+
impl MulAssign for Complex {
33+
fn mul_assign(&mut self, rhs: Self) {
34+
*self = *self * rhs;
35+
}
36+
}
37+
38+
impl Add for Complex {
39+
type Output = Complex;
40+
41+
fn add(self, rhs: Self) -> Self::Output {
42+
Self {
43+
real: self.real + rhs.real,
44+
imag: self.imag + rhs.imag,
45+
}
46+
}
47+
}
48+
49+
impl Sub for Complex {
50+
type Output = Complex;
51+
52+
fn sub(self, rhs: Self) -> Self::Output {
53+
Self {
54+
real: self.real - rhs.real,
55+
imag: self.imag - rhs.imag,
56+
}
57+
}
58+
}
59+
60+
fn fft(a: &mut [Complex], invert: bool) {
61+
let n = a.len();
62+
assert!(n.is_power_of_two());
63+
let shift = usize::BITS - n.trailing_zeros();
64+
for i in 1..n {
65+
let j = (i << shift).reverse_bits();
66+
assert!(j < n);
67+
if i < j {
68+
a.swap(i, j);
69+
}
70+
}
71+
for len in (1..).map(|x| 1 << x).take_while(|s| *s <= n) {
72+
let half = len / 2;
73+
let alpha = std::f64::consts::PI * 2.0 / (len as f64);
74+
let cos = f64::cos(alpha);
75+
let sin = f64::sin(alpha) * (if invert { -1.0 } else { 1.0 });
76+
let complex_angle = Complex {
77+
real: cos,
78+
imag: sin,
79+
};
80+
for start in (0..n).step_by(len) {
81+
let mut mult = Complex::ONE;
82+
for j in 0..half {
83+
let u = a[start + j];
84+
let v = a[start + half + j] * mult;
85+
a[start + j] = u + v;
86+
a[start + j + half] = u - v;
87+
mult *= complex_angle;
88+
}
89+
}
90+
}
91+
if invert {
92+
for elem in a.iter_mut().take(n) {
93+
let n = n as f64;
94+
elem.imag /= n;
95+
elem.real /= n;
96+
}
97+
}
98+
}
99+
100+
fn fft_multiply_raw(mut a: Vec<Complex>, mut b: Vec<Complex>) -> Vec<Complex> {
101+
assert!(a.len().is_power_of_two());
102+
assert!(b.len().is_power_of_two());
103+
assert_eq!(a.len(), b.len());
104+
fft(&mut a, false);
105+
fft(&mut b, false);
106+
for (x, y) in a.iter_mut().zip(b.iter()) {
107+
*x *= *y;
108+
}
109+
fft(&mut a, true);
110+
a
111+
}
112+
113+
fn fft_multiply_complex(mut a: Vec<Complex>, mut b: Vec<Complex>) -> Vec<Complex> {
114+
let expected_size = (a.len() + b.len() - 1).next_power_of_two();
115+
a.resize(expected_size, Complex::ZERO);
116+
b.resize(expected_size, Complex::ZERO);
117+
fft_multiply_raw(a, b)
118+
}
119+
120+
/// # Example
121+
/// ```
122+
/// use programming_team_code_rust::numbers::fft::fft_multiply;
123+
///
124+
/// let a = vec![1.0, 2.0, 3.0];
125+
/// let b = vec![4.0, 5.0, 6.0];
126+
/// let c = fft_multiply(a, b);
127+
/// let expected = vec![4.0, 13.0, 28.0, 27.0, 18.0];
128+
/// for (x, y) in c.iter().zip(expected.iter()) {
129+
/// assert!((x - y).abs() < 1e-15);
130+
/// }
131+
/// ```
132+
///
133+
/// # Complexity (n = max(a.len(), b.len()))
134+
/// Given for practical use cases, although consider it with a relatively large constant factor
135+
/// - Time: O(n log n)
136+
/// - Space: O(n)
137+
pub fn fft_multiply(a: Vec<f64>, b: Vec<f64>) -> Vec<f64> {
138+
let a: Vec<_> = a.iter().map(|&x| Complex { real: x, imag: 0.0 }).collect();
139+
let b: Vec<_> = b.iter().map(|&x| Complex { real: x, imag: 0.0 }).collect();
140+
fft_multiply_complex(a, b).iter().map(|c| c.real).collect()
141+
}

src/numbers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
//! # Number Theory and Combinatorics
22
pub mod binom;
3+
pub mod fft;
34
pub mod primes;

0 commit comments

Comments
 (0)