Eisen's Blog

© 2023. All rights reserved.

一道简单的图算法题 El Nino !

October 25, 2017

javaalgorithm

HackerEarch 是我最近发觉的一个不错的算法练习的平台。这里首先对这个平台做一些介绍,包括它与其他平台相比的优缺点。然后介绍一个简单的图算法题的解法。

HackerEarth 又一个编程练习网站

虽然自己的工作和复杂的算法关系不是很大,但是始终觉得这种对基础的熟练掌握和了解是做各种计算机行业相关工作的基础,然后自己也是总觉得自己在这方面有着一些兴趣,所有会经常找一些练习算法的平台刷刷题。之前有使用过 TopCoder、HackerRank、LeetCode 等。和 HackerRank 类似,HackerEarth 也是出自印度小哥之手。

虽然它本身有很多的问题,比如测试的执行不够稳定,比如很多题目的测试用例很少,比如 web page 存在各种展示的 bug等。但是 HackerEarth 有一个功能让我眼前一亮:它对各种题目进行了清晰的分门别类,让我这种练习者可以更针对性的进行训练。

另外一个对练习者友好的功能是它可以随意的查看一个 test cases 的输入和输出。如果我的一个提交出错了可以直接点击链接就能看到具体的输入和输出,虽然 HackerRank 也可以查看具体某一个 TestCase 但是需要自己通过在其网站上获取的点券购买,相比这里直接点击还是稍稍麻烦了一点。下面则是介绍一个我在这里写的一道图算法练习题。

El Nino ! 一个 DFS 算法题目

题目在这里 题目大概翻译一下如下:

你有一个包含 N 个节点的树 T,其中根为节点 1,还有一个长度为 M 的无重复整数的数组 A。

需要你去找到所有符合这样条件的三元组 (U, V, K),其中 V 属于 U 节点的子树,然后他们之间的距离恰好是 A[K]。而两个节点之间的距离就是他们之间的最短路径,即边的个数。

输入:第一行包含空格分割的两个整数 N 和 M 第二行包含 M 个空格分离的整数 第三行包含 N - 1 个空格分隔的整数,其中第 i 个整数是 i + 1 节点的父节点(注意,前者是第 i 个整数,而后者是指 节点 i + 1

输出:T 中产生的符合上述描述的三元组的个数

尝试 1

首先可是知道这是一个以 DFS (深度遍历)为中心的算法:每次深度遍历增加一个节点之后将其所有的祖先节点记录下来,并且检查数组 A 是否存在对应的距离。如果存在则计数 +1。于是有了第一个版本的提交:

import java.util.*;
 
class TestClass {
    private static int cnt = 0;
    
    public static void main(String args[] ) throws Exception {
        Scanner scan = new Scanner(System.in);
        int n = scan.nextInt();
        int m = scan.nextInt();
        
        // 用 HashSet 存储数组 A
        Set<Integer> lens = new HashSet<Integer>();
        for (int i = 0; i < m; i++) {
            lens.add(scan.nextInt());
        }
        
        // 用邻接表建立图
        LinkedList<Integer>[] g = new LinkedList[n];
        for (int i = 0; i < n; i++) {
            g[i] = new LinkedList<Integer>();
        }
        
        for (int i = 0; i < n - 1; i++) {
            int node = scan.nextInt() - 1;
            g[node].add(i + 1);
        }
        
        int[] dist = new int[n];
        Set<Integer> ancestors = new HashSet<>();
        ancestors.add(0);
 
        dfs(g, 0, dist, ancestors, lens);
        
        System.out.println(cnt);
    }
    
    public static void dfs(LinkedList<Integer>[] g, 
                           int u, 
                           int[] dist, 
                           Set<Integer> ancestors, 
                           Set<Integer> lens) {
        for (int v : g[u]) {
            dist[v] = dist[u] + 1;
            
            // 每遍历到一个新的节点就去检查其与所有祖先的距离
            for (int a : ancestors) {
                // 计算与祖先的距离
                int distance = dist[v] - dist[a];
                // 如果这个距离在 A 中存在就计数 +1
                if (lens.contains(distance)) {
                    cnt++;
                }
            }
            
            // 在下一次遍历之前添加已经处理完的节点
            ancestors.add(v);
            dfs(g, v, dist, ancestors, lens);
            // 在处理兄弟节点之前将当前节点从祖先列表前删除
            ancestors.remove(v);
        }
    }
}

然后只有三分之一左右的用例过了...其中一大部分超时,还有一部分报错。虽然我觉得这个算法的效率可能有点问题,但是我没想到居然有些是报错了。点开其中的一个报错的 Case 一看,发现自己没有处理 A 中包含 0 的情况。

尝试 2,考虑包含长度 0 的情况

知道错在哪里就好修改了。距离为零意味着三元组 [U, V, K] 中 U 和 V 应该一样。那么就是 N 的个数。所以我就做一个特殊检查:如果 A.contains(0) 那么 cnt += n

import java.util.*;
 
class TestClass {
    private static int cnt = 0;
    
    public static void main(String args[] ) throws Exception {
        Scanner scan = new Scanner(System.in);
        int n = scan.nextInt();
        int m = scan.nextInt();
        
        // 用 HashSet 存储数组 A
        Set<Integer> lens = new HashSet<Integer>();
        for (int i = 0; i < m; i++) {
            lens.add(scan.nextInt());
        }
        
        // 如果 A 中包含 0 计数 +n
        if (lens.contains(0)) {
            cnt += n;
        }
        
        // 用邻接表建立图
        LinkedList<Integer>[] g = new LinkedList[n];
        for (int i = 0; i < n; i++) {
            g[i] = new LinkedList<Integer>();
        }
        
        for (int i = 0; i < n - 1; i++) {
            int node = scan.nextInt() - 1;
            g[node].add(i + 1);
        }
        
        int[] dist = new int[n];
        Set<Integer> ancestors = new HashSet<>();
        ancestors.add(0);
 
        dfs(g, 0, dist, ancestors, lens);
        
        System.out.println(cnt);
    }
    
    // 和尝试 1 一样,就省略了
}

这次所有通过的就是真的通过了,但是还是有大量的超时和系统错误。大概分析一下,超时肯定是说算法的复杂度太高了,可以看到默认的 DFS 的复杂度是 O(V + E) 对于树来说,就是 O(V + V - 1) 而已。但是每次处理新的节点时都有一个过程:

// 每遍历到一个新的节点就去检查其与所有祖先的距离
for (int a : ancestors) {
    // 计算与祖先的距离
    int distance = dist[v] - dist[a];
    // 如果这个距离在 A 中存在就计数 +1
    if (lens.contains(distance)) {
        cnt++;
    }
}

是一个 O(V) 的复杂度,所以整体的复杂度有 O(V^2)

而报错则很有可能是栈溢出了,那么需要将目前的递归变成遍历。

尝试 3,将 O(v) 提升到 O(log(V))

计算所有祖先中存在的距离与 A 数组中的距离耗费的时间是 O(v),但是仔细想想,对于一个没有权重的树来说,其实每一个路径中,如果当前节点 v 有 k 个祖先,那么出现的距离一定是 0 到 k 了。我们不用每次都去检查有哪些祖先,其祖先的个数就是 d[v] 而已,那么上面的那部分可以改成如下的样子:

Set<Integer> set = range(dist[v]);
set.retainAll(lens);
cnt += set.size();

其中 range 为:

public static Set<Integer> range(int n) {
    Set<Integer> s = new HashSet<>();
    for (int i = 1; i <= n; i++) {
        s.add(i);
    }
    return s;
}

但是很遗憾,这样的复杂度并没有减少,retainAllrange 依然需要 O(V) 的复杂度。不过我觉得重点就是如何将目前这个 O(V) 做提升了。这个时候我想到了 binary search...我们当前的问题是在 range(dist[v]) 和数组 A 之间找并集,那么如果将 A 排序,我们找的就是在 A 中最大的那个 小于等于 dist[v] 的索引,这就是可以用 binary search 解决的问题了。于是,最终的提交如下:

import java.util.*;
 
class TestClass {
    
    public static void main(String args[] ) throws Exception {
        Scanner scan = new Scanner(System.in);
        int n = scan.nextInt();
        int m = scan.nextInt();
        
        int[] lens = new int[m];
        for (int i = 0; i < m; i++) {
            lens[i] = scan.nextInt();
        }
 
        Arrays.sort(lens);
        
        LinkedList<Integer>[] g = new LinkedList[n];
        for (int i = 0; i < n; i++) {
            g[i] = new LinkedList<Integer>();
        }
        
        for (int i = 0; i < n - 1; i++) {
            int node = scan.nextInt() - 1;
            g[node].add(i + 1);
        }
 
        long cnt = 0L;
        
        int[] dist = new int[n];
        // 同时用遍历替代递归,引入一个 stack 处理 DFS
        LinkedList<Integer> stack = new LinkedList<>();
        stack.add(0);
        int idx = bsearch(dist[0], lens);
        cnt += idx + 1;
 
        while (!stack.isEmpty()) {
            int u = stack.pop();
            for (int v : g[u]) {
                dist[v] = dist[u] + 1;
                idx = bsearch(dist[v], lens);
                // 索引 + 1 就是并集的个数,如果没找到 -1 + 1 就是 0
                cnt += idx + 1;
                stack.add(v);
            }
        }
        
        System.out.println(cnt);
    }
 
    // 找到最大的那个小于等 value 的索引
    public static int bsearch(int value, int[] lens) {
        int low = 0;
        int high = lens.length - 1;
 
        while (low < high) {
            int mid = (low + high + 1) / 2;
            if (lens[mid] <= value) {
                low = mid;
            } else {
                high = mid - 1;
            }
        }
 
        if (lens[low] > value) {
            return -1;
        }
        return low;
    }
}

这样修改之后就有了 O(Vlog(V)) 的复杂度,成功的 AC 了。

当然,其实中间还是有很多的坑,比如之前采用的计数是 int 现在改成了 long,再比如搜索索引之后就不再需要考虑有没有 0 距离的特殊情况了。

虽然看起来是一个不算难的 DFS 算法题,HackerEarth 也把它标记成简单,但是依然花了一番心思才优化成这个最终的版本。

参考资料

  1. binary search 以及其变种
  2. 图算法,深度优先遍历 MIT 在线课程
  3. HackerEarth EL NINO!