1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
//! Disjoint-set union (union–find) data structure.
use std::collections::{HashMap, HashSet};
/// One slot of the disjoint-set forest.
#[derive(Debug, Clone, Copy)]
enum Node {
    /// Representative of a set; the payload is the number of elements in that set.
    Root(usize),
    /// Non-representative element; the payload is the index of its parent.
    Child(usize),
}

/// Disjoint-set union (union–find) over the elements `0..n`.
///
/// Sets are merged by size and lookups compress the traversed path, so every
/// query method takes `&mut self`.
#[derive(Clone, Debug)]
pub struct DisjointSetUnion {
    uf: Vec<Node>,
}

impl DisjointSetUnion {
    /// Creates `n` singleton sets, one for each element in `0..n`.
    pub fn new(n: usize) -> DisjointSetUnion {
        DisjointSetUnion {
            uf: vec![Node::Root(1); n],
        }
    }

    /// Returns the representative of the set containing `target`.
    ///
    /// Every element visited on the way up is re-pointed directly at the
    /// representative (path compression).
    ///
    /// # Panics
    ///
    /// Panics if `target` is not a valid element index.
    pub fn root(&mut self, target: usize) -> usize {
        // First pass: climb to the representative.
        let mut root = target;
        while let Node::Child(parent) = self.uf[root] {
            root = parent;
        }

        // Second pass: hang every node on the walked path directly under it.
        let mut current = target;
        while let Node::Child(parent) = self.uf[current] {
            self.uf[current] = Node::Child(root);
            current = parent;
        }

        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged and `false` if `x` and
    /// `y` already belonged to the same set. The strictly larger set keeps its
    /// representative; on a tie the representative of `y`'s set is kept.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn unite(&mut self, x: usize, y: usize) -> bool {
        let rx = self.root(x);
        let ry = self.root(y);
        if rx == ry {
            return false;
        }

        // Both roots are already known, so read their sizes directly instead
        // of walking the forest again.
        let size_x = self.root_size(rx);
        let size_y = self.root_size(ry);

        let (winner, loser) = if size_x > size_y { (rx, ry) } else { (ry, rx) };
        self.uf[winner] = Node::Root(size_x + size_y);
        self.uf[loser] = Node::Child(winner);

        true
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn is_same(&mut self, x: usize, y: usize) -> bool {
        self.root(x) == self.root(y)
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn size(&mut self, x: usize) -> usize {
        let root = self.root(x);
        self.root_size(root)
    }

    /// Returns every element that shares a set with `x` (including `x` itself).
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn get_same_group(&mut self, x: usize) -> HashSet<usize> {
        let root = self.root(x);
        (0..self.uf.len())
            .filter(|&i| self.root(i) == root)
            .collect()
    }

    /// Returns all sets, keyed by their current representative.
    pub fn get_all_groups(&mut self) -> HashMap<usize, HashSet<usize>> {
        let mut groups: HashMap<usize, HashSet<usize>> = HashMap::new();
        for i in 0..self.uf.len() {
            let root = self.root(i);
            groups.entry(root).or_default().insert(i);
        }
        groups
    }

    /// Reads the size stored at `root`, which must be a representative index.
    fn root_size(&self, root: usize) -> usize {
        match self.uf[root] {
            Node::Root(size) => size,
            Node::Child(_) => unreachable!("root_size called on a non-root node"),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises merging, membership, size and group enumeration on four elements.
    #[test]
    fn test_dsu() {
        let mut d = DisjointSetUnion::new(4);
        assert!(d.unite(0, 1));
        assert!(d.is_same(0, 1));
        assert!(d.unite(1, 2));
        assert!(d.is_same(0, 2));
        assert_eq!(d.size(0), 3);
        assert!(!d.is_same(0, 3));
        // Merging elements that already share a set reports that nothing changed.
        assert!(!d.unite(0, 2));

        // Compare group contents without depending on which element is the
        // representative or on hash iteration order.
        let mut groups: Vec<Vec<usize>> = d
            .get_all_groups()
            .values()
            .map(|group| {
                let mut members: Vec<usize> = group.iter().copied().collect();
                members.sort_unstable();
                members
            })
            .collect();
        groups.sort();
        assert_eq!(groups, vec![vec![0, 1, 2], vec![3]]);
    }
}