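22. Union-Find (Disjoint Set)

Union-find is a tree-based data structure that, as its name suggests, handles merging and querying of disjoint sets. It supports two operations:

  • Union: merge two disjoint sets into one.

  • Find: determine whether two elements belong to the same set.

Union-find does not support splitting a set, although a modified version can support deleting a single element from its set.

The key idea of union-find is that each set is represented by one of its own elements, the root.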
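One common way to obtain the single-element deletion mentioned above is to separate elements from tree nodes. The C++ sketch below illustrates that idea under this assumption; it is not a standard API, and `DeletableDSU`, `root`, `same`, `unite` and `detach` are hypothetical names. `detach(x)` moves `x` into a brand-new singleton node and leaves its old node in the tree as a placeholder, so elements whose paths pass through that node are unaffected. The price is one extra node per deletion.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical sketch: element x lives at tree node id[x]. "Deleting" x from
// its set re-maps it to a fresh singleton node; the old node stays behind as a
// placeholder so other elements' paths through it remain valid.
struct DeletableDSU {
    std::vector<int> parent;  // parent[node]; grows by one node per detach
    std::vector<int> id;      // id[element] = node currently representing it

    explicit DeletableDSU(int n) : parent(n), id(n) {
        std::iota(parent.begin(), parent.end(), 0);
        std::iota(id.begin(), id.end(), 0);
    }

    // Root of a tree node, with path compression.
    int root(int v) {
        if (parent[v] != v) parent[v] = root(parent[v]);
        return parent[v];
    }

    bool same(int x, int y) { return root(id[x]) == root(id[y]); }

    void unite(int x, int y) { parent[root(id[x])] = root(id[y]); }

    // Remove element x from its current set; x becomes a singleton again.
    void detach(int x) {
        id[x] = static_cast<int>(parent.size());
        parent.push_back(id[x]);  // the new node is its own root
    }
};
```

For example, after `unite(0, 1)`, `unite(1, 2)` and `detach(1)`, elements 0 and 2 are still connected while 1 is in a set of its own.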

22.1. Basic Version

Initialization

Suppose there are \(n\) elements. An array parent[] stores the parent of each element; initially, every element is its own parent.

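const int MAXN = 100005; // assumed capacity; must be larger than n
int parent[MAXN];
inline void init(const int n)
{
  // elements 0..n are initialized, so both 0-based and 1-based numbering work
  for(int i = 0; i <= n; ++i) parent[i] = i;
}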

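Find

The representative element is found recursively: follow parent links upward until reaching the root (a root is recognized by being its own parent). Two elements belong to the same set if and only if they have the same root.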

inline int find(const int x)
{
  if(parent[x] == x) return x; // x is a root
  else return find(parent[x]); // otherwise keep climbing
}

Union

Find the representatives of the two sets, then make the first representative a child of the second (the opposite direction works just as well).

// `union` is a reserved word in C++, so the function is named unite
inline void unite(const int x, const int y)
{
  parent[find(x)] = find(y);
}
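Putting the three pieces together gives a complete basic version. The sketch below also adds a helper, `same`, implementing the same-set check described under Find; the capacity `MAXN = 1000` and the name `same` are assumptions made for illustration. The union function is called `unite` because `union` is a C++ keyword.

```cpp
#include <cassert>

// Complete basic version with no optimizations. MAXN is an assumed capacity.
const int MAXN = 1000;
int parent[MAXN];

void init(int n) {
    for (int i = 0; i <= n; ++i) parent[i] = i;  // each element is its own root
}

int find(int x) {
    return parent[x] == x ? x : find(parent[x]);  // climb until a self-parent node
}

void unite(int x, int y) {
    parent[find(x)] = find(y);  // hang x's root below y's root
}

// Hypothetical helper: two elements are in the same set iff their roots coincide.
bool same(int x, int y) {
    return find(x) == find(y);
}
```

For instance, after `init(5)`, `unite(1, 2)` and `unite(3, 4)`, `same(1, 2)` is true and `same(2, 3)` is false; a further `unite(2, 4)` connects 1 and 3.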

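22.2. Path Compression

The basic version is inefficient: unions can make the trees deeper and deeper, so finding the root from the bottom of a tree becomes increasingly expensive.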
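To see how bad this can get, the sketch below replays the basic union rule `parent[find(x)] = find(y)` on the sequence unite(0, 1), unite(1, 2), and so on up to unite(n - 2, n - 1). The helper names `depth` and `naiveChain` are hypothetical. The tree degenerates into a single chain, so element 0 ends up n - 1 links from the root and every later find(0) walks the whole chain.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Number of parent links followed from x up to its root.
int depth(const std::vector<int>& parent, int x) {
    int d = 0;
    while (parent[x] != x) {
        x = parent[x];
        ++d;
    }
    return d;
}

// Parent array after the naive unions unite(0,1), unite(1,2), ..., unite(n-2,n-1),
// where unite(x, y) sets parent[root(x)] = root(y) as in the basic version.
std::vector<int> naiveChain(int n) {
    std::vector<int> parent(n);
    std::iota(parent.begin(), parent.end(), 0);
    auto root = [&parent](int x) {
        while (parent[x] != x) x = parent[x];
        return x;
    };
    for (int i = 0; i + 1 < n; ++i) parent[root(i)] = root(i + 1);
    return parent;
}
```

With n = 5 the resulting parent array is {1, 2, 3, 4, 4}, and `depth(p, 0)` is 4.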

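Since we only care about which root an element belongs to, we want the path from every element to its root to be as short as possible, ideally a single step. This is achieved by re-pointing every node visited during a find directly at the root, which makes later finds on those nodes very fast.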

inline int find(const int x)
{
  if(parent[x] == x) return x;
  else
  {
    parent[x] = find(parent[x]); // hang x directly under the root
    return parent[x];
  }
}
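The recursive find can overflow the call stack on a very long chain, for example the first find on a degenerate tree built by the basic version. An equivalent iterative form, sketched below under the hypothetical name `findIterative`, makes two passes: the first locates the root, the second re-points every node on the path directly at it.

```cpp
#include <cassert>
#include <vector>

// Iterative path compression (same effect as the recursive version).
int findIterative(std::vector<int>& parent, int x) {
    int root = x;
    while (parent[root] != root) root = parent[root];  // pass 1: locate the root
    while (parent[x] != root) {                         // pass 2: compress the path
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}
```

On the chain {1, 2, 3, 4, 4}, `findIterative(p, 0)` returns 4 and leaves every entry of `p` equal to 4.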

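22.3. Heuristic Union

A union can increase the depth of a tree (the length of its longest chain): every element of the tree being attached moves one step farther from the root, so later finds follow longer paths. Path compression mitigates this, but compression itself also takes time.

Heuristic union attaches the simpler tree to the more complex one. Only the nodes of the simpler tree move farther from the root, and when "simpler" means shallower, the merged tree gets deeper only if both trees have the same depth.

An array rank[] records the depth of the tree rooted at each root (for a non-root node, rank equals the depth of the subtree rooted at it). Initially every element's rank is 1; a union attaches the tree with the smaller rank to the tree with the larger rank.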

int rank[MAXN]; // note: clashes with std::rank if `using namespace std;` is in effect

inline void init(const int n)
{
  for(int i = 0; i <= n; ++i)
  {
    parent[i] = i;
    rank[i] = 1; // each element starts as a one-node tree of depth 1
  }
}

inline void unite(const int x, const int y)
{
  const int rx = find(x);
  const int ry = find(y);
  if(rank[rx] <= rank[ry]) parent[rx] = ry;
  else parent[ry] = rx;
  if(rank[rx] == rank[ry] && rx != ry) rank[ry]++; // equal depths and different roots: the new root ry is one level deeper
}
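Why this keeps trees shallow can be checked directly. Without path compression, rank stays exactly equal to the tree's depth (counted in nodes, starting from 1 as above), and a tree of rank r always holds at least \(2^{r-1}\) nodes, so no element is more than \(\log_2 n\) links from its root. The sketch below (struct `RankDSU` and helper `depth` are hypothetical names) leaves out path compression on purpose so that this stays observable.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Union by rank WITHOUT path compression, so rank equals the tree depth in nodes.
struct RankDSU {
    std::vector<int> parent, rank;

    explicit RankDSU(int n) : parent(n), rank(n, 1) {
        std::iota(parent.begin(), parent.end(), 0);
    }

    int find(int x) const {
        while (parent[x] != x) x = parent[x];
        return x;
    }

    void unite(int x, int y) {
        int rx = find(x), ry = find(y);
        if (rx == ry) return;
        if (rank[rx] > rank[ry]) std::swap(rx, ry);  // rx is now the shallower root
        parent[rx] = ry;
        if (rank[rx] == rank[ry]) ++rank[ry];  // equal depths: merged tree is one level deeper
    }

    // Number of parent links from x up to its root.
    int depth(int x) const {
        int d = 0;
        while (parent[x] != x) {
            x = parent[x];
            ++d;
        }
        return d;
    }
};
```

Running unite(i, i + 1) for i = 0 to 6 on 8 elements, a sequence that yields a chain in the basic version, leaves every element at depth at most 1; merging two 2-element trees instead yields rank 3 and depth 2.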

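However, when path compression is also used, every find restructures the tree, turning the queried node and all of its ancestors into direct children of the root. rank then drifts away from the real tree height and is only an upper bound on it. Another heuristic is to attach the tree with fewer nodes to the tree with more nodes (union by size).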
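A minimal union-by-size sketch, combined with path compression, is shown below. Unlike rank, the size stored at a root stays the exact element count however path compression reshapes the tree, because a find never changes which node is the root. The struct name `SizeDSU` is hypothetical, and returning whether a merge happened from `unite` is a design choice of this sketch.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Union by size plus path compression. size[r] is exact for every root r.
struct SizeDSU {
    std::vector<int> parent, size;

    explicit SizeDSU(int n) : parent(n), size(n, 1) {
        std::iota(parent.begin(), parent.end(), 0);
    }

    int find(int x) {
        if (parent[x] != x) parent[x] = find(parent[x]);  // path compression
        return parent[x];
    }

    // Returns false if x and y were already in the same set.
    bool unite(int x, int y) {
        int rx = find(x), ry = find(y);
        if (rx == ry) return false;
        if (size[rx] > size[ry]) std::swap(rx, ry);  // rx is now the smaller tree
        parent[rx] = ry;
        size[ry] += size[rx];
        return true;
    }
};
```

After `unite(0, 1)`, `unite(2, 3)` and `unite(1, 3)` on 6 elements, `size[find(0)]` is 4, and a further `unite(0, 2)` returns false because 0 and 2 are already connected.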

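22.4. Complexity

In short, for a union-find structure with \(n\) elements, the space complexity is \(\mathcal{O}(n)\). With both path compression and heuristic union, any sequence of \(m\) union and find operations takes \(\mathcal{O}(m \log^* n)\) time in total, i.e. amortized \(\mathcal{O}(\log^* n)\) per operation, where \(\log^*\) is the iterated logarithm. A tighter bound, due to Tarjan, is \(\mathcal{O}(m\,\alpha(n))\), where \(\alpha\) is the extremely slowly growing inverse Ackermann function.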
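As a small worked example of how slowly \(\log^*\) grows (logarithms taken to base 2), it can be defined recursively and evaluated on a couple of values:

\[
\log^* n =
\begin{cases}
0, & n \le 1,\\
1 + \log^*(\log_2 n), & n > 1,
\end{cases}
\]

\[
\log^* 65536 = 4 \quad (65536 \to 16 \to 4 \to 2 \to 1), \qquad \log^* 2^{65536} = 5.
\]

Hence \(\log^* n \le 5\) for every \(n \le 2^{65536}\), which covers any input that could actually be stored, so the \(\log^* n\) factor behaves like a small constant in practice.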

22.5. Python Reference Code


"""
A union-find disjoint set data structure.
"""

# 2to3 sanity
from __future__ import (
    absolute_import, division, print_function, unicode_literals,
)


class UnionFind(object):
    r"""Union-find disjoint sets data structure.
    Union-find is a data structure that maintains disjoint set
    (called connected components or components in short) membership,
    and makes it easier to merge (union) two components, and to find
    if two elements are connected (i.e., belong to the same
    component).
    This implements the "weighted-quick-union-with-path-compression"
    union-find algorithm (union by size plus path splitting, a one-pass
    form of path compression).  Only works if elements are immutable
    (hashable) objects.
    Worst case for a sequence of unions and finds:
    :math:`O(N + M \log^* N)`, with :math:`N` elements and :math:`M`
    operations. The function :math:`\log^*` is the number of times needed
    to take :math:`\log` of a number until reaching 1. In practice, the
    amortized cost of each operation is nearly constant [1]_.
    Terms
    -----
    Component
        Elements belonging to the same disjoint set
    Connected
        Two elements are connected if they belong to the same component.
    Union
        The operation where two components are merged into one.
    Root
        An internal representative of a disjoint set.
    Find
        The operation to find the root of a disjoint set.
    Parameters
    ----------
    elements : NoneType or container, optional, default: None
        The initial list of elements.
    Attributes
    ----------
    n_elts : int
        Number of elements.
    n_comps : int
        Number of disjoint sets or components.
    Implements
    ----------
    __len__
        Calling ``len(uf)`` (where ``uf`` is an instance of ``UnionFind``)
        returns the number of elements.
    __contains__
        For ``uf`` an instance of ``UnionFind`` and ``x`` an immutable object,
        ``x in uf`` returns ``True`` if ``x`` is an element in ``uf``.
    __getitem__
        For ``uf`` an instance of ``UnionFind`` and ``i`` an integer,
        ``res = uf[i]`` returns the element stored in the ``i``-th index.
        If ``i`` is not a valid index an ``IndexError`` is raised.
    __setitem__
        For ``uf`` an instance of ``UnionFind``, ``i`` an integer and ``x``
        an immutable object, ``uf[i] = x`` changes the element stored at the
        ``i``-th index. If ``i`` is not a valid index an ``IndexError`` is
        raised; if ``x`` is already stored at another index a ``ValueError``
        is raised.
    .. [1] http://algs4.cs.princeton.edu/lectures/
    """

    def __init__(self, elements=None):
        self.n_elts = 0  # current num of elements
        self.n_comps = 0  # the number of disjoint sets or components
        self._next = 0  # next available id
        self._elts = []  # the elements
        self._indx = {}  # dict mapping elt -> index in _elts
        self._par = []  # parent: for the internal tree structure
        self._siz = []  # size of the component - correct only for roots

        if elements is None:
            elements = []
        for elt in elements:
            self.add(elt)

    def __repr__(self):
        return (
            '<UnionFind:\n\telts={},\n\tsiz={},\n\tpar={},\nn_elts={},n_comps={}>'
            .format(
                self._elts,
                self._siz,
                self._par,
                self.n_elts,
                self.n_comps,
            ))

    def __len__(self):
        return self.n_elts

    def __contains__(self, x):
        return x in self._indx

    def __getitem__(self, index):
        if index < 0 or index >= self._next:
            raise IndexError('index {} is out of bound'.format(index))
        return self._elts[index]

    def __setitem__(self, index, x):
        if index < 0 or index >= self._next:
            raise IndexError('index {} is out of bound'.format(index))
        if x in self._indx and self._indx[x] != index:
            raise ValueError('{} is already an element'.format(x))
        # keep the element -> index mapping in sync with _elts
        del self._indx[self._elts[index]]
        self._elts[index] = x
        self._indx[x] = index

    def add(self, x):
        """Add a single disjoint element.
        Parameters
        ----------
        x : immutable object
        Returns
        -------
        None
        """
        if x in self:
            return
        self._elts.append(x)
        self._indx[x] = self._next
        self._par.append(self._next)
        self._siz.append(1)
        self._next += 1
        self.n_elts += 1
        self.n_comps += 1

    def find(self, x):
        """Find the root of the disjoint set containing the given element.
        Parameters
        ----------
        x : immutable object
        Returns
        -------
        int
            The (index of the) root.
        Raises
        ------
        ValueError
            If the given element is not found.
        """
        if x not in self._indx:
            raise ValueError('{} is not an element'.format(x))

        p = self._indx[x]
        while p != self._par[p]:
            # path splitting: point p at its grandparent, then move up
            q = self._par[p]
            self._par[p] = self._par[q]
            p = q
        return p

    def connected(self, x, y):
        """Return whether the two given elements belong to the same component.
        Parameters
        ----------
        x : immutable object
        y : immutable object
        Returns
        -------
        bool
            True if x and y are connected, False otherwise.
        """
        return self.find(x) == self.find(y)

    def union(self, x, y):
        """Merge the components of the two given elements into one.
        Parameters
        ----------
        x : immutable object
        y : immutable object
        Returns
        -------
        None
        """
        # Initialize if they are not already in the collection
        for elt in [x, y]:
            if elt not in self:
                self.add(elt)

        xroot = self.find(x)
        yroot = self.find(y)
        if xroot == yroot:
            return
        if self._siz[xroot] < self._siz[yroot]:
            self._par[xroot] = yroot
            self._siz[yroot] += self._siz[xroot]
        else:
            self._par[yroot] = xroot
            self._siz[xroot] += self._siz[yroot]
        self.n_comps -= 1

    def component(self, x):
        """Find the connected component containing the given element.
        Parameters
        ----------
        x : immutable object
        Returns
        -------
        set
        Raises
        ------
        ValueError
            If the given element is not found.
        """
        if x not in self:
            raise ValueError('{} is not an element'.format(x))
        root = self.find(x)
        return {elt for elt in self._elts if self.find(elt) == root}

    def components(self):
        """Return the list of connected components.
        Returns
        -------
        list
            A list of sets.
        """
        # group the original element objects by root (plain Python, so any
        # hashable element type works, including an empty structure)
        groups = {}
        for elt in self._elts:
            groups.setdefault(self.find(elt), set()).add(elt)
        return list(groups.values())

    def component_mapping(self):
        """Return a dict mapping elements to their components.
        The returned dict has the following semantics:
            `elt -> component containing elt`
        If x, y belong to the same component, the comp(x) and comp(y)
        are the same objects (i.e., share the same reference). Changing
        comp(x) will reflect in comp(y).  This is done to reduce
        memory.
        But this behaviour should not be relied on.  There may be
        inconsistency arising from such assumptions or lack thereof.
        If you want to do any operation on these sets, use caution.
        For example, instead of
        ::
            s = uf.component_mapping()[item]
            s.add(stuff)
            # This will have side effect in other sets
        do
        ::
            s = set(uf.component_mapping()[item]) # or
            s = uf.component_mapping()[item].copy()
            s.add(stuff)
        or
        ::
            s = uf.component_mapping()[item]
            s = s | {stuff}  # Now s is different
        Returns
        -------
        dict
            A dict with the semantics: `elt -> component containing elt`.
        """
        comps = {}
        for comp in self.components():
            comps.update({x: comp for x in comp})
            # Change ^this^, if you want a different behaviour:
            # If you don't want to share the same set to different keys:
            # comps.update({x: set(comp) for x in comp})
        return comps

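22.6. Kruskal's Algorithm

Kruskal's algorithm for minimum spanning trees is built on union-find. First, put all edges into a priority queue ordered so that lighter edges come first (a min-heap). Then pop the edges one by one: if the two endpoints lie in different sets, merge the sets and add the edge's weight to the total; if they are already in the same set, skip the edge, because it would close a cycle. As soon as a merge produces a set that contains every vertex, the accepted edges form a minimum spanning tree.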


// NC159 Minimum spanning tree
// https://www.nowcoder.com/practice/735a34ff4672498b95660f43b7fcd628?tpId=117&&tqId=37869&rp=1&ru=/ta/job-code-high&qru=/ta/job-code-high/question-ranking

#include <numeric>
#include <queue>
#include <vector>
using namespace std;

struct comparator
{
    bool operator()(const vector<int>& a, const vector<int>& b) const
    {
        return a[2] > b[2]; // heavier edges get lower priority, so this is a min-heap
    }
};
class Solution
{
public:
    /**
    * Returns the minimum total cost of connecting all n households.
    * @param n int the village has n households (vertices 1..n)
    * @param cost vector<vector<int>> each entry has 3 values: two households and the cost of connecting them
    * @return int
    */
    int miniSpanningTree(int n, vector<vector<int> >& cost)
    {
        if(n <= 1) return 0;
        vector<int> parents(n+1, 0);
        iota(parents.begin(), parents.end(), 0); // every vertex starts as its own root
        vector<int> capacity(n+1, 1);            // capacity[r]: number of vertices in the set rooted at r
        priority_queue<vector<int>, vector<vector<int> >, comparator> edges;
        for(auto& edge: cost) edges.push(edge);
        int c = 0; // total weight of the accepted edges
        int v = 0; // size of the set produced by the latest merge
        while(!edges.empty())
        {
            auto edge = edges.top();
            edges.pop();
            bool u = union_(parents, capacity, edge[0], edge[1], v);
            if(u) c += edge[2];
            if(v == n) break; // the merged set contains every vertex: MST complete
        }
        return c;
    }
private:
    // find with path compression
    int find_(vector<int>& parents, int x)
    {
        if(x == parents[x]) return x;
        else
        {
            parents[x] = find_(parents, parents[x]);
            return parents[x];
        }
    }
    // union by size; returns true if two sets were merged and stores the merged size in v
    bool union_(vector<int>& parents, vector<int>& capacity, int x, int y, int& v)
    {
        x = find_(parents, x);
        y = find_(parents, y);
        if(x != y)
        {
            if(capacity[x] >= capacity[y])
            {
                parents[y] = x;
                capacity[x] += capacity[y];
                v = capacity[x];
            }
            else
            {
                parents[x] = y;
                capacity[y] += capacity[x];
                v = capacity[y];
            }
            return true;
        }
        return false;
    }
};
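A common alternative, sketched below, sorts the edge list once instead of using a priority queue and stops as soon as n - 1 edges have been accepted, since a spanning tree on n vertices has exactly n - 1 edges. The function name `kruskal`, the `-1` return value for a disconnected graph, and the use of path halving instead of union by size are choices made for this sketch, not part of the NC159 problem. Vertices are assumed to be numbered 1..n as in the code above.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Kruskal with a sorted edge list. Each edge is {u, v, w}, vertices 1..n.
// Returns the MST weight, or -1 if the graph is not connected.
long long kruskal(int n, std::vector<std::vector<int>> edges) {
    std::sort(edges.begin(), edges.end(),
              [](const std::vector<int>& a, const std::vector<int>& b) { return a[2] < b[2]; });
    std::vector<int> parent(n + 1);
    std::iota(parent.begin(), parent.end(), 0);
    auto root = [&parent](int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];  // path halving: skip to the grandparent
            x = parent[x];
        }
        return x;
    };
    long long total = 0;
    int used = 0;  // number of accepted edges
    for (const auto& e : edges) {
        if (used == n - 1) break;  // spanning tree already complete
        int ru = root(e[0]), rv = root(e[1]);
        if (ru == rv) continue;  // same set: the edge would close a cycle
        parent[ru] = rv;
        total += e[2];
        ++used;
    }
    return used == n - 1 ? total : -1;
}
```

For the triangle with edges (1, 2, 1), (2, 3, 2) and (1, 3, 3) it returns 3, taking the two lightest edges; for 4 vertices with only the edges (1, 2, 5) and (3, 4, 1) it returns -1.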
