9. 回溯

总体思路是深度优先遍历(DFS)。

9.1. 子集树

子集树大小为 \(\mathcal{O}(m^n)\)\(m\) 是树的分支个数( \(m\) 叉树),\(n\) 是树的深度。

算法描述:

 1void backtrack(int t)
 2{
 3  if(t >= n) output(x);
 4  else
 5  {
 6    for(int i = 0; i < m; ++i)
 7    {
 8      x[t] = i;
 9      if(constrain(t) and bound(t)) backtrack(t+1);
10    }
11  }
12}

9.2. 排列树

排列树大小为 \(\mathcal{O}(n!)\)

算法描述:

 1void backtrack(int t)
 2{
 3  if(t >= n) output(x);
 4  else
 5  {
 6    for(int k = t; k < n; ++k)
 7    {
 8      swap(x[t], x[k]);
 9      if(constrain(t) and bound(t)) backtrack(t+1);
10      swap(x[t], x[k]);
11    }
12  }
13}

9.3. 0-1背包问题

算法描述:

 1void backtrack(int t)
 2{
 3  if(t >= n)
 4  {
 5    best_value = curr_value;
 6    bext_x = x;
 7    return;
 8  }
 9  else
10  {
11    if(curr_weight + w[t] <= W)
12    {
13      x[t] = 1;
14      curr_weight += w[t]; // 进入左子树
15      curr_value += v[t];
16      backtrack(t+1);
17
18      curr_weight -= w[t]; // 状态恢复
19      curr_value -= v[t];
20    }
21    x[t] = 0;
22    backtrack(t+1); // 进入右子树
23  }
24}

9.4. 八皇后问题

八皇后问题共有 92 组解。

 1bool place(int t, int* x)
 2{
 3  for(int j = 0; j < t; ++j)
 4  {
 5    if(x[j] == x[t] || abs(j - t) == abs(x[j] - x[t])) return false; // 在同一列或同一斜线上
 6  }
 7  return true;
 8}
 9
10void backtrack(int t, int n, int* x, int& sum)
11{
12  if(t == n) ++sum;
13  else
14  {
15    for(int i = 0; i < n; ++i)
16    {
17      x[t] = i;
18      if(place(t, x)) backtrack(t+1, n, x, sum);
19    }
20  }
21}

9.5. 实例

  • 全排列(含重复元素)。Hint:在交换第 \(i\) 个元素与第 \(j\) 个元素之前,要求数组的 \([i, j)\) 区间中的元素没有与第 \(j\) 个元素重复。

    https://blog.csdn.net/so_geili/article/details/71078945

    \(\color{darkgreen}{Code}\)

     1int cnt = 0; // 不同排列的个数
     2
     3//检查[from,to)之间的元素和第to号元素是否相同
     4bool isRepeat(int* A, int from, int to)
     5{
     6    for(int i = from; i < to; i++)
     7    {
     8        if(A[to] == A[i]) return true;
     9    }
    10    return false;
    11}
    12
    13void permutation(int* A, int t, int n)
    14{
    15    if(t == n)
    16    {
    17        cnt++;
    18        Output(A);
    19    }
    20    else
    21    {
    22        for(int j = t; j < n; j++)
    23        {
    24            if(!isRepeat(A, t, j))
    25            {
    26                swap(A[t], A[j]);
    27                permutation(A, t+1, n);
    28                swap(A[t], A[j]);
    29            }
    30        }
    31    }
    32}
    
  • Next Permutation 下一个排列。Hint:从后往前先找到第一个开始下降的数字 \(x\) (下标 \(i\) ),再从后往前找到第一个比 \(x\) 大的数 \(y\) (下标 \(j\) );交换 \(x\)\(y\) ;翻转区间 \([i+1, end]\)

    https://www.cnblogs.com/grandyang/p/4428207.html

    \(\color{darkgreen}{Code}\)

     1class Solution
     2{
     3public:
     4    void nextPermutation(vector<int> &num)
     5    {
     6        int i, j, n = num.size();
     7        for (i = n - 2; i >= 0; --i)
     8        {
     9            if (num[i + 1] > num[i])
    10            {
    11                for (j = n - 1; j > i; --j)
    12                {
    13                    if (num[j] > num[i]) break;
    14                }
    15                swap(num[i], num[j]);
    16                reverse(num.begin() + i + 1, num.end());
    17                return;
    18            }
    19        }
    20        reverse(num.begin(), num.end()); // 当前排列是最大的排列,则翻转为最小的排列
    21    }
    22};
    
  • 按字典序输出序列 \(1,2,...,n\) 的全排列。Hint:深度优先遍历。

    \(\color{darkgreen}{Code}\)

     1void DFS(int* arr, bool* used, int n, int t)
     2{
     3    if(t == n)
     4    {
     5        for(int i = 0; i < n; ++i) cout << arr[i] << ends;
     6        cout << endl;
     7        return;
     8    }
     9    for(int digit = 1; digit <= n; ++digit)
    10    {
    11        if(!used[digit - 1])
    12        {
    13            used[digit - 1] = true;
    14            arr[t] = digit;
    15            DFS(arr, used, n, t+1);
    16            used[digit - 1] = false;
    17        }
    18    }
    19}
    
  • [LeetCode] Permutation Sequence 输出序列 \(1,2,...,n\) 的第 \(k\) 个排列(字典序)。Hint:方法一,按字典序深度优先遍历;方法二,逐步缩小搜索范围,如: \(perm [ 1,2,3 ] = \{1 + perm [ 2,3 ] \} + \{2 + perm [ 1,3 ] \} + \{3 + perm [ 1,2 ] \}\)

    https://leetcode.com/problems/permutation-sequence/

    \(\color{darkgreen}{Code}\)

     1// https://leetcode.com/problems/permutation-sequence/discuss/22507/%22Explain-like-I'm-five%22-Java-Solution-in-O(n)
     2
     3class Solution
     4{
     5public:
     6    string getPermutation(int n, int k)
     7    {
     8        string nums = "";
     9        vector<int> factorial(n+1, 1);
    10        for(int i = 1; i <= n; ++i)
    11        {
    12            nums += to_string(i);
    13            factorial[i] = i;
    14        }
    15        partial_sum(factorial.begin(), factorial.end(), factorial.begin(), multiplies<int>()); // f(n) = n!, f(0) = 1
    16
    17        string res = "";
    18        while(n)
    19        {
    20            int id = (k - 1) / factorial[n-1]; // k - 1,下标从 0 开始
    21            res += nums[id];
    22            nums.erase(nums.begin() + id); // 得到 n - 1 个数的序列
    23            k -= id * factorial[n-1]; // 在 n - 1 个数的序列中继续查找第 k - id * factorial[n-1] 个排列
    24            --n;
    25        }
    26        return res;
    27    }
    28};
    
  • 输出序列 \(1,2,...,n\) 的所有子集(组合),共 \(2^n\) 组。Hint:方法一,回溯,二叉子集树;方法二,递归,序列每增加一个数,组合数增加一倍,增加的这些组合是在之前的组合的基础上插入该数得到的; 方法三,当 \(n < 32\) ,可以使用一个 int 型的变量 \(k\)\(1 \leqslant k \leqslant 2^n\) )来表示组合的状态,当该变量的二进制表示的第 \(i\) 位为 1,则表示当前组合中包含数字 \(i\)

    \(\color{darkgreen}{Code}\)

     1// 方法一,回溯
     2
     3void backtrack(int n, vector<int>& tmp, vector<vector<int>>& res)
     4{
     5  if (n == 0)
     6  {
     7    res.push_back(tmp);
     8    return;
     9  }
    10  backtrack(n - 1, tmp, res); // 不包含 n
    11  tmp.push_back(n);
    12  backtrack(n - 1, tmp, res); // 包含 n
    13  tmp.pop_back();
    14}
    15
    16vector<vector<int>> combination(int n)
    17{
    18  assert(n > 0);
    19  vector<vector<int>> res;
    20  vector<int> tmp;
    21  backtrack(n, tmp, res);
    22  return res;
    23}
    
     1// 方法二,递归
     2
     3void combinationRecursive(int n, vector<vector<int>>& res)
     4{
     5  if (n == 1)
     6  {
     7    res[1].push_back(1);
     8    return;
     9  }
    10
    11  combinationRecursive(n - 1, res);
    12
    13  int pre_num = pow(2, n - 1); // 在 1 ~ n-1 的组合上插入数字 n
    14  for (int i = 0; i < pre_num; ++i)
    15  {
    16    res[i + pre_num].push_back(n);
    17    for (int j = 0; j < res[i].size(); ++j)
    18    {
    19      res[i + pre_num].push_back(res[i][j]);
    20    }
    21  }
    22}
    23
    24vector<vector<int>> combination(int n)
    25{
    26  assert(n > 0);
    27  int num = pow(2, n);
    28  vector<vector<int>> res(num, vector<int>{});
    29  combinationRecursive(n, res);
    30  return res;
    31}
    
     1// 方法三,统计二进制中 1 的个数
     2
     3vector<vector<int>> combination(int n)
     4{
     5  assert(n > 0);
     6  int num = pow(2, n);
     7  vector<vector<int>> res(num, vector<int>{});
     8  int k = num - 1;
     9  while (k >= 0)
    10  {
    11    int pos = n - 1;
    12    while (pos >= 0)
    13    {
    14      if (k & (1 << pos)) res[k].push_back(pos + 1);
    15      --pos;
    16    }
    17    --k;
    18  }
    19  return res;
    20}
    
  • 输出整数集合的所有组合(包含重复元素)。Hint:统计每个元素的频率 \(f\) ,在组合过程中,该元素可取的个数最少为零个,最多为 \(f\) 个;回溯。

    https://leetcode.com/problems/subsets-ii/

    \(\color{darkgreen}{Code}\)

     1from collections import Counter
     2class Solution:
     3    def backtrack(self, ints: List[int], freqs: List[int], tmp: List[int], res:List[List[int]], t:int):
     4        if t == len(ints):
     5            res.append(tmp[:]) ## 注意:这里必须是添加tmp的副本到res中,否则随着tmp改变,res中的元素也会改变
     6            return
     7        for k in range(freqs[t] + 1):
     8            tmp.extend([ints[t]] * k)
     9            self.backtrack(ints, freqs, tmp, res, t+1)
    10            if tmp:
    11                for _ in range(k): tmp.pop()
    12    def subsetsWithDup(self, nums: List[int]) -> List[List[int]]:
    13        cnt = Counter(nums)
    14        ints, freqs = list(cnt.keys()), list(cnt.values()) ## python3 中需要把 dict_keys、dict_values 类型转换为 list
    15        res = []
    16        tmp = []
    17        self.backtrack(ints, freqs, tmp, res, 0)
    18        return res
    
  • [LeetCode] Distinct Subsequences II 子序列个数(含重复元素的组合数)。Hint:方法一,动态规划,设 \(dp[k]\) 是以 \(S[k]\) 结尾的子序列个数, 如果不考虑重复,则 \(dp[k] = dp[0] + dp[1] + \cdots + dp[k-1] + 1\) ,即在前面的子序列末尾追加 \(S[k]\) ,或 \(S[k]\) 单独构成的子序列( \(+1\) ); 然而要减掉以 \(S[k]\) 结尾的重复子序列: \(dp[k]\ -= dp[r],\ 0 \leqslant r < k \ \&\& \ S[k]=S[r]\) ; 方法二,回溯:设当前子序列集合最后一个元素的下标为 \(i\) ,在把当前字符(设下标为 \(t\) )加入子序列集合时, 需要考虑区间 \((i, t)\) (如果当前子序列集合为空,区间为 \([0, t)\) )内是否有 \(S[t]\) 的重复元素,如果有,则不能把 \(S[t]\) 插入当前子序列中,否则就造成重复;回溯方法严重超时。

    https://leetcode.com/problems/distinct-subsequences-ii

    \(\color{darkgreen}{Code}\)

     1// 方法一
     2
     3class Solution
     4{
     5public:
     6    int distinctSubseqII(string S)
     7    {
     8        if(S.empty()) return 0;
     9        vector<long long> dp(S.size(), 0);
    10        dp[0] = 1;
    11        for(size_t i = 1; i < S.size(); ++i)
    12        {
    13            dp[i] = accumulate(dp.begin(), dp.begin() + i, 1LL); // + 1,这里的 1LL 表示 long long int,默认的 int 型导致溢出,结果错误
    14            for(size_t k = 0; k < i; ++k)
    15            {
    16                if(S[k] == S[i]) // 减去重复
    17                {
    18                    dp[i] -= dp[k];
    19                    while(dp[i] < 0) dp[i] += 1000000007; // 减法操作可能会使得 dp[i] < 0
    20                }
    21            }
    22            dp[i] = dp[i] % 1000000007;
    23        }
    24        return accumulate(dp.begin(), dp.end(), 0LL) % 1000000007; // 0LL
    25    }
    26};
    
     1// 方法一改进型
     2
     3// 设 dp[l] 是以 S[l] 结尾的不重复子序列个数(定义与上面的方法一相同),
     4// 设 end[i] 是以字符 'a' + i 结尾的子序列个数,0 <= i < 26,S[l] = 'a' + i,
     5// 如果该字符出现在多个位置,如 {j,k,l},则 end[i] = dp[j] + dp[k] + dp[l],
     6// 由方法一可知:dp[l] = \sum_{m=0}^{l-1} dp[m] + 1 - dp[j] - dp[k],
     7// 因此 end[i] = \sum_{m=0}^{l-1} dp[m] + 1 = \sum_{n=0}^25 end[n] + 1
     8
     9class Solution
    10{
    11public:
    12    int distinctSubseqII(string S)
    13    {
    14        if(S.empty()) return 0;
    15        long long end[26] = {0};
    16        for(size_t i = 0; i < S.size(); ++i)
    17        {
    18            end[S[i] - 'a'] = accumulate(end, end + 26, 1LL) % 1000000007;
    19        }
    20        return accumulate(end, end + 26, 0LL) % 1000000007;
    21    }
    22};
    
     1// 方法二
     2
     3class Solution
     4{
     5public:
     6    int distinctSubseqII(string S)
     7    {
     8        if(S.empty()) return 0;
     9        int ans = 0;
    10        vector<int> subS; // 当前子序列集合
    11        DFS(S, 0, subS, ans);
    12        return ans;
    13    }
    14private:
    15    bool hasRepeat(string& S, vector<int>& subS, int t)
    16    {
    17        bool repeat = false;
    18        size_t i;
    19        if(subS.empty()) i = 0;
    20        else i = subS.back() + 1;
    21        for(; i < t; ++i)
    22        {
    23            if(S[i] == S[t])
    24            {
    25                repeat = true;
    26                break;
    27            }
    28        }
    29        return repeat;
    30    }
    31    void DFS(string& S, int t, vector<int>& subS, int &ans)
    32    {
    33        if(t == S.size())
    34        {
    35            if(!subS.empty()) ans = (ans + 1) % 1000000007;
    36            return;
    37        }
    38        DFS(S, t+1, subS, ans); // 当前子序列集合不包括 S[t]
    39        if(!hasRepeat(S, subS, t)) // 区间 (i, t) (或 [0, t))内不包括 S[t] 的重复字符,才可以把 S[t] 加入当前子序列集合
    40        {
    41            subS.push_back(t);
    42            DFS(S, t+1, subS, ans);
    43            subS.pop_back();
    44        }
    45    }
    46};
    
  • Word search 查找字符串路径。

    https://leetcode.com/problems/word-search/

    \(\color{darkgreen}{Code}\)

     1class Solution {
     2public:
     3    bool findPath(vector<vector<char>>& board, string word, bool** flag, int x, int y, int k)
     4    {
     5        if(k == word.size()) return true;
     6        for(int t = 0; t < 4; ++t)
     7        {
     8            int tx = x + mv[t][0];
     9            int ty = y + mv[t][1];
    10
    11            if(flag[tx+1][ty+1] && board[tx][ty] == word[k])
    12            {
    13                flag[tx+1][ty+1] = false; // 设置 flag
    14                if(findPath(board, word, flag, tx, ty, k+1)) return true;
    15                flag[tx+1][ty+1] = true; // flag 还原
    16            }
    17
    18        }
    19        return false;
    20    }
    21    bool exist(vector<vector<char>>& board, string word) {
    22        if(word == "") return true;
    23        if(board.size()==0) return false;
    24        int M = board.size();
    25        int N = board[0].size();
    26        bool** flag = new bool*[M+2]; // 设置一圈边界,标记为 false,后面访问 board 中的 4 个领域不用再判断是否越界;flag 的大小为 (M+2)x(N+2)
    27        for(int m = 0; m < M+2; ++m)
    28        {
    29            flag[m] = new bool[N+2];
    30            for(int n = 0; n < N+2; ++n)
    31            {
    32                if(m == 0 || m == M+1 || n == 0 || n == N+1) flag[m][n] = false;
    33                else flag[m][n] = true;
    34            }
    35        }
    36        bool EXIST = false;
    37        for(int i = 0; i < M; ++i)
    38        {
    39            for(int j = 0; j < N; ++j)
    40            {
    41                if(board[i][j] == word[0])
    42                {
    43                    flag[i+1][j+1] = false; // 注意: flag 的下标与 board 相差 1
    44                    if(findPath(board, word, flag, i, j, 1))
    45                    {
    46                        EXIST = true;
    47                        break; // 跳出第二重循环
    48                    }
    49                    flag[i+1][j+1] = true; // flag 还原
    50                }
    51            }
    52            if(EXIST) break; // 跳出第一重循环
    53        }
    54
    55        for(int m = 0; m < M+2; ++m) delete[] flag[m];
    56        delete[] flag;
    57
    58        return EXIST;
    59    }
    60private:
    61    static const int mv[4][2];
    62};
    63
    64const int Solution::mv[4][2] = {{-1,0},{0,-1},{0,1},{1,0}};
    
  • Knuth-Shuffle,公平的洗牌算法:生成每一种排列的概率都是 \(\frac{1}{n!}\)

    \(\color{darkgreen}{Code}\)

    1void shuffle(int* arr, int n)
    2{
    3  for(int i = n - 1; i >= 0; --i)
    4  {
    5    swap(arr[i], arr[rand()%(i+1)]);
    6  }
    7}