二叉树的拓展


拓展:最近公共祖先系列解题框架

如果说笔试的时候经常遇到各种动归回溯这类稍有难度的题目,那么面试会倾向于一些比较经典的问题,难度不算大,而且也比较实用。

本文就用 Git 引出一个经典的算法问题:最近公共祖先(Lowest Common Ancestor,简称 LCA)。

git pull 这个命令我们经常会用,它默认是使用 merge 方式将远端别人的修改拉到本地;如果带上参数 git pull -r,就会使用 rebase 的方式将远端修改拉到本地。

这二者最直观的区别就是:merge 方式合并的分支会看到很多「分叉」,而 rebase 方式合并的分支就是一条直线。但无论哪种方式,如果存在冲突,Git 都会检测出来并让你手动解决冲突。

那么问题来了,Git 是如何检测两条分支是否存在冲突的呢?

rebase 命令为例,比如下图的情况,我站在 dev 分支执行 git rebase master,然后 dev 就会接到 master 分支之上:

img

这个过程中,Git 是这么做的:

首先,找到这两条分支的最近公共祖先 LCA,然后从 master 节点开始,重演 LCAdev 几个 commit 的修改,如果这些修改和 LCAmastercommit 有冲突,就会提示你手动解决冲突,最后的结果就是把 dev 的分支完全接到 master 上面。

那么,Git 是如何找到两条不同分支的最近公共祖先的呢?这就是一个经典的算法问题了,下面我来由浅入深讲一讲。


寻找一个元素

先不管最近公共祖先问题,我请你实现一个简单的算法:

给你输入一棵没有重复元素的二叉树根节点 root 和一个目标值 val,请你写一个函数寻找树中值为 val 的节点。

函数签名如下:

def find(root: TreeNode, val: int) -> TreeNode:

这个函数应该很容易实现对吧,比如我这样写代码:

# 定义:在以 root 为根的二叉树中寻找值为 val 的节点
def find(root: TreeNode, val: int) -> TreeNode:
    # base case
    if not root:
        return None
    # 看看 root.val 是不是要找的
    if root.val == val:
        return root
    # root 不是目标节点,那就去左子树找
    left = find(root.left, val)
    if left:
        return left
    # 左子树找不着,那就去右子树找
    right = find(root.right, val)
    if right:
        return right
    # 实在找不到了
    return None

这段代码应该不用我多解释了,下面的可视化面板展示了这段代码的执行过程,你可以多次点击 这一行,即可展示出函数在二叉树上的搜索过程:

 算法可视化面板

下面我将基于这段代码做一些简单的改写,请你分析一下我的改动会造成什么影响。

首先,如果修改一下 return 的位置:

def find(root: TreeNode, val: int) -> TreeNode:
    if not root:
        return None
    # 前序位置
    if root.val == val:
        return root
    # root 不是目标节点,去左右子树寻找
    left = find(root.left, val)
    right = find(root.right, val)
    # 看看哪边找到了
    return left if left else right

这段代码也可以达到目的,但是实际运行的效率会低一些。

原因也很简单,如果你能够在左子树找到目标节点,还有没有必要去右子树找了?没有必要。但这段代码还是会去右子树找一圈,所以效率相对差一些。

下面的可视化面板展示了这段代码的执行过程,你可以多次点击 这一行,即可展示出函数在二叉树上的搜索过程,对比上面的可视化面板,这个函数会遍历二叉树的所有节点:

 算法可视化面板

那么,是不是说这种写法一定会遍历二叉树的所有节点呢?不一定,还有一个特殊情况,即要找的目标节点恰好就是根节点。

因为你是在前序位置判断 if (root.val == val) 的,所以这种特殊情况下函数可以直接结束。

更进一步,我把对 root.val 的判断从前序位置移动到后序位置:

def find(root: TreeNode, val: int) -> TreeNode:
    if root is None:
        return None
    # 先去左右子树寻找
    left = find(root.left, val)
    right = find(root.right, val)
    # 后序位置,看看 root 是不是目标节点
    if root.val == val:
        return root
    # root 不是目标节点,再去看看哪边的子树找到了
    return left if left is not None else right

这段代码相当于你先去左右子树找,最后才检查 root,依然可以到达目的,但是效率会进一步下降,因为这种写法必然会遍历二叉树的每一个节点

没办法,你是在后序位置判断,那么就算根节点就是目标节点,你也要去左右子树遍历完所有节点才能判断出来。

下面的可视化面板展示了这段代码的执行过程,你可以多次点击 这一行,即可展示出函数在二叉树上的搜索过程:

 算法可视化面板

最后,我再改一下题目,现在不让你找值为 val 的节点,而是寻找值为 val1 val2 的节点,函数签名如下:

def find(root: TreeNode, val1: int, val2: int) -> TreeNode:

为什么要写这样一个奇怪的 find 函数呢?因为最近公共祖先系列问题的解法都是把这个函数作为框架的

这和我们第一次实现的 find 函数基本上是一样的,而且你应该知道可以有多种写法,比方说我可以这样写代码:

# 定义:在以 root 为根的二叉树中寻找值为 val1 或 val2 的节点
def find(root, val1, val2):
    # base case
    if root is None:
        return None
    # 前序位置,看看 root 是不是目标值
    if root.val == val1 or root.val == val2:
        return root
    # 去左右子树寻找
    left = find(root.left, val1, val2)
    right = find(root.right, val1, val2)

    # 后序位置,已经知道左右子树是否存在目标值
    return left if left is not None else right

当然,这种写法会有重复遍历的问题,不过先不急着优化,最近公共祖先的一系列算法问题还就得基于这种写法展开。

下面一道一道题目来看,后文我用 LCA(Lowest Common Ancestor)作为最近公共祖先节点的缩写。

236. 二叉树的最近公共祖先

先来看看力扣第 236 题「二叉树的最近公共祖先」:

给你输入一棵不含重复值的二叉树,以及存在于树中的两个节点 pq,请你计算 pq 的最近公共祖先节点。

比如输入这样一棵二叉树:

img

如果 p 是节点 6q 是节点 7,那么它俩的 LCA 就是节点 5

img

当然,pq 本身也可能是 LCA,比如这种情况 q 本身就是 LCA 节点:

img

两个节点的最近公共祖先其实就是这两个节点向根节点的「延长线」的交汇点,那么对于任意一个节点,它怎么才能知道自己是不是 pq 的最近公共祖先?

如果一个节点能够在它的左右子树中分别找到 pq,则该节点为 LCA 节点

这就要用到之前实现的 find 函数了,只需在后序位置添加一个判断逻辑,即可改造成寻找最近公共祖先的解法代码:

class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        return self.find(root, p.val, q.val)

    # 在二叉树中寻找 val1 和 val2 的最近公共祖先节点
    def find(self, root: 'TreeNode', val1: int, val2: int) -> 'TreeNode':
        if root is None:
            return None
        # 前序位置
        if root.val == val1 or root.val == val2:
            # 如果遇到目标值,直接返回
            return root
        left = self.find(root.left, val1, val2)
        right = self.find(root.right, val1, val2)
        # 后序位置,已经知道左右子树是否存在目标值
        if left is not None and right is not None:
            # 当前节点是 LCA 节点
            return root
        
        return left if left is not None else right

find 函数的后序位置,如果发现 leftright 都非空,就说明当前节点是 LCA 节点,即解决了第一种情况:

img

find 函数的前序位置,如果找到一个值为 val1val2 的节点则直接返回,恰好解决了第二种情况:

img

因为题目说了 pq 一定存在于二叉树中(这点很重要),所以即便我们遇到 q 就直接返回,根本没遍历到 p,也依然可以断定 pq 底下,q 就是 LCA 节点。

下面这个可视化面板展示了这段代码的执行过程,你可以多次点击 这一行,即可展示出函数在二叉树上的搜索过程,你也可以自行修改测试用例玩一玩:

 算法可视化面板

结合可视化面板,我们也能发现一个优化的点,就是当我们在左子树找到目标 LCA 节点后,算法并没有结束,而是把右子树又遍历了一遍,这其实是没有必要的。

有前面的铺垫,你是不是想做类似这样的优化?

// root 不是目标节点,那就去左子树找
TreeNode left = find(root.left, val);
if (left != null) {
    return left;
}
// 左子树找不着,那就去右子树找
TreeNode right = find(root.right, val);
if (right != null) {
    return right;
}

不行的,因为我们本来就要同时去左子树和右子树寻找,来判断当前节点是不是 LCA

如果你非要优化,只能用一个外部变量来辅助判断是否已经找到答案,如果已经找到 LCA,则不再继续遍历二叉树:

class Solution:
    def __init__(self):
        # 用一个外部变量来记录是否已经找到 LCA 节点
        self.lca = None

    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        return self.find(root, p.val, q.val)

    def find(self, root: 'TreeNode', val1: int, val2: int) -> 'TreeNode':
        if root is None:
            return None
        # 如果已经找到 LCA 节点,直接返回
        if self.lca is not None:
            return None

        if root.val == val1 or root.val == val2:
            return root
        left = self.find(root.left, val1, val2)
        right = self.find(root.right, val1, val2)
        if left is not None and right is not None:
            # 当前节点是 LCA 节点,记录下来
            self.lca = root 
            return root
        
        return left if left is not None else right

这段算法的可视化面板如下,你可以多次点击 这一行,即可展示出函数在二叉树上的搜索过程,找到 LCA 节点后,算法就不再继续遍历右侧的子树了:

 算法可视化面板

这样,标准的最近公共祖先问题就解决了,接下来看看这个题目有什么变体。

1676. 二叉树的最近公共祖先 IV

比如力扣第 1676 题「二叉树的最近公共祖先 IV」:

依然给你输入一棵不含重复值的二叉树,但这次不是给你输入 pq 两个节点了,而是给你输入一个包含若干节点的列表 nodes(这些节点都存在于二叉树中),让你算这些节点的最近公共祖先。

函数签名如下:

def lowestCommonAncestor(root: TreeNode, nodes: List[TreeNode]) -> TreeNode:

比如还是这棵二叉树:

img

输入 nodes = [7,4,6],那么函数应该返回节点 5

看起来怪吓人的,实则解法逻辑是一样的,把刚才的代码逻辑稍加改造即可解决这道题:

class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
        # 将列表转化成哈希集合,便于判断元素是否存在
        values = set()
        for node in nodes:
            values.add(node.val)
        
        return self.find(root, values)
    
    def find(self, root: 'TreeNode', values: 'set') -> 'TreeNode':
        if root is None:
            return None
        # 前序位置
        if root.val in values:
            return root

        left = self.find(root.left, values)
        right = self.find(root.right, values)
        # 后序位置,已经知道左右子树是否存在目标值
        if left is not None and right is not None:
            # 当前节点是 LCA 节点
            return root
        
        return left if left is not None else right

类比一下上一道题应该不难理解这个解法。当找到 LCA 节点后,也可以提前停止算法,这个优化就留给你吧。

需要注意的是,这两道题的题目都明确告诉我们这些节点必定存在于二叉树中,如果没有这个前提条件,就需要修改代码了

1644. 二叉树的最近公共祖先 II

比如力扣第 1644 题「二叉树的最近公共祖先 II」:

给你输入一棵不含重复值的二叉树的,以及两个节点 pq,如果 pq 不存在于树中,则返回空指针,否则的话返回 pq 的最近公共祖先节点。

在解决标准的最近公共祖先问题时,我们在 find 函数的前序位置有这样一段代码:

// 前序位置
if (root.val == val1 || root.val == val2) {
    // 如果遇到目标值,直接返回
    return root;
}

我也进行了解释,因为 pq 都存在于树中,所以这段代码恰好可以解决最近公共祖先的第二种情况:

img

但对于这道题来说,pq 不一定存在于树中,所以你不能遇到一个目标值就直接返回,而应该对二叉树进行完全搜索(遍历每一个节点),如果发现 pq 不存在于树中,那么是不存在 LCA 的。

回想我在文章开头分析的几种 find 函数的写法,哪种写法能够对二叉树进行完全搜索来着?

这种:

def find(root: TreeNode, val: int) -> TreeNode:
    if root is None:
        return None
    # 先去左右子树寻找
    left = find(root.left, val)
    right = find(root.right, val)
    # 后序位置,看看 root 是不是目标节点
    if root.val == val:
        return root
    # root 不是目标节点,再去看看哪边的子树找到了
    return left if left is not None else right

那么解决这道题也是类似的,我们只需要把前序位置的判断逻辑放到后序位置即可:

class Solution:
    def __init__(self):
        # 用于记录 p 和 q 是否存在于二叉树中
        self.foundP = False
        self.foundQ = False

    def lowestCommonAncestor(self, root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
        res = self.find(root, p.val, q.val)
        if not self.foundP or not self.foundQ:
            return None
        # p 和 q 都存在二叉树中,才有公共祖先
        return res
        
    # 在二叉树中寻找 val1 和 val2 的最近公共祖先节点
    def find(self, root, val1, val2):
        if not root:
            return None
        left = self.find(root.left, val1, val2)
        right = self.find(root.right, val1, val2)
        
        # 后序位置,判断当前节点是不是 LCA 节点
        if left and right:
            return root
        
        # 后序位置,判断当前节点是不是目标值
        if root.val == val1 or root.val == val2:
            # 找到了,记录一下
            if root.val == val1:
                self.foundP = True
            if root.val == val2:
                self.foundQ = True
            return root

        return left if left else right

这样改造,对二叉树进行完全搜索,同时记录 pq 是否同时存在树中,从而满足题目的要求。

这段算法的可视化面板如下,我构造了一个 q 不在树中的场景,多次点击 即可查看函数搜索二叉树的过程:

 算法可视化面板

接下来,我们再变一变,如果让你在二叉搜索树中寻找 pq 的最近公共祖先,应该如何做呢?

235. 二叉搜索树的最近公共祖先

看力扣第 235 题「二叉搜索树的最近公共祖先」:

给你输入一棵不含重复值的二叉搜索树,以及存在于树中的两个节点 pq,请你计算 pq 的最近公共祖先节点。

把之前的解法代码复制过来肯定也可以解决这道题,但没有用到 BST「左小右大」的性质,显然效率不是最高的。

在标准的最近公共祖先问题中,我们要在后序位置通过左右子树的搜索结果来判断当前节点是不是 LCA

TreeNode left = find(root.left, val1, val2);
TreeNode right = find(root.right, val1, val2);

// 后序位置,判断当前节点是不是 LCA 节点
if (left != null && right != null) {
    return root;
}

但对于 BST 来说,根本不需要老老实实去遍历子树,由于 BST 左小右大的性质,将当前节点的值与 val1val2 作对比即可判断当前节点是不是 LCA

假设 val1 < val2,那么 val1 <= root.val <= val2 则说明当前节点就是 LCA;若 root.valval1 还小,则需要去值更大的右子树寻找 LCA;若 root.valval2 还大,则需要去值更小的左子树寻找 LCA

依据这个思路就可以写出解法代码:

class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        # 保证 val1 较小,val2 较大
        val1 = min(p.val, q.val)
        val2 = max(p.val, q.val)
        return self.find(root, val1, val2)

    # 在 BST 中寻找 val1 和 val2 的最近公共祖先节点
    def find(self, root: 'TreeNode', val1: int, val2: int) -> 'TreeNode':
        if root is None:
            return None
        if root.val > val2:
            # 当前节点太大,去左子树找
            return self.find(root.left, val1, val2)
        if root.val < val1:
            # 当前节点太小,去右子树找
            return self.find(root.right, val1, val2)
        # val1 <= root.val <= val2
        # 则当前节点就是最近公共祖先
        return root

上述代码的可视化面板如下,可以看到在 BST 中寻找最近公共祖先的过程非常快,不需要遍历整棵树:

 算法可视化面板

1650. 二叉树的最近公共祖先 III

再看最后一道最近公共祖先的题目吧,力扣第 1650 题「二叉树的最近公共祖先 III」,这次输入的二叉树节点比较特殊,包含指向父节点的指针。题目会给你输入一棵存在于二叉树中的两个节点 pq,请你返回它们的最近公共祖先。函数签名如下:

class Node:
    def __init__(self):
        self.val = None
        self.left = None
        self.right = None
        self.parent = None

# 函数签名
def lowestCommonAncestor(p: Node, q: Node) -> Node:

由于节点中包含父节点的指针,所以二叉树的根节点就没必要输入了。

这道题其实不是公共祖先的问题,而是单链表相交的问题,你把 parent 指针想象成单链表的 next 指针,题目就变成了:

给你输入两个单链表的头结点 pq,这两个单链表必然会相交,请你返回相交点。

img

我在前文 单链表的六大解题套路 中详细讲解过求链表交点的问题,具体思路在本文就不展开了,直接给出本题的解法代码:

class Solution:
    # 施展链表双指针技巧
    def lowestCommonAncestor(self, p: 'Node', q: 'Node') -> 'Node':
        a, b = p, q
        while a != b:
            # a 走一步,如果走到根节点,转到 q 节点
            if a is None:
                a = q
            else:
                a = a.parent
            # b 走一步,如果走到根节点,转到 p 节点
            if b is None:
                b = p
            else:
                b = b.parent
        return a

上述代码的可视化面板如下,可以多次点击 这一行代码查看两个指针交替移动的过程:

 算法可视化面板

至此,5 道最近公共祖先的题目就全部讲完了,前 3 道题目从一个基本的 find 函数衍生出解法,后 2 道比较特殊,分别利用了 BST 和单链表相关的技巧,希望本文对你有启发。

拓展:如何计算完全二叉树的节点数

如果让你数一下一棵普通二叉树有多少个节点,这很简单,只要在二叉树的遍历框架上加一点代码就行了。

但是,力扣第第 222 题「完全二叉树的节点个数」给你一棵完全二叉树,让你计算它的节点个数,你会不会?算法的时间复杂度是多少?

这个算法的时间复杂度应该是$ O(logN∗logN),如果你心中的算法没有达到这么高效,那么本文就是给你写的。

关于「完全二叉树」和「满二叉树」等名词的定义,可以参考基础知识章节的 二叉树基础

一、思路分析

现在回归正题,如何求一棵完全二叉树的节点个数呢?

# 输入一棵完全二叉树,返回节点总数
def countNodes(root: TreeNode) -> int:

如果是一个普通二叉树,显然只要向下面这样遍历一边即可,时间复杂度 $O(N)$:

def countNodes(root: TreeNode) -> int:
    if root == None:
        return 0
    return 1 + countNodes(root.left) + countNodes(root.right)

那如果是一棵二叉树,节点总数就和树的高度呈指数关系:

def countNodes(root: TreeNode) -> int:
    h = 0
    # 计算树的高度
    while root:
        root = root.left
        h += 1
    # 节点总数就是 2^h - 1
    return 2 ** h - 1

完全二叉树比普通二叉树特殊,但又没有满二叉树那么特殊,计算它的节点总数,可以说是普通二叉树和完全二叉树的结合版,先看代码:

class Solution:
    def countNodes(self, root: TreeNode) -> int:
        l = root
        r = root
        hl = 0
        hr = 0
        # 沿最左侧和最右侧分别计算高度
        while l is not None:
            l = l.left
            hl += 1
        while r is not None:
            r = r.right
            hr += 1
        # 如果左右侧计算的高度相同,则是一棵满二叉树
        if hl == hr:
            return pow(2, hl) - 1
        # 如果左右侧的高度不同,则按照普通二叉树的逻辑计算
        return 1 + self.countNodes(root.left) + self.countNodes(root.right)
 算法可视化面板

结合刚才针对满二叉树和普通二叉树的算法,上面这段代码应该不难理解,就是一个结合版,但是其中降低时间复杂度的技巧是非常微妙的

二、复杂度分析

开头说了,这个算法的时间复杂度是 $O(log⁡N×log⁡N)$,这是怎么算出来的呢?

直觉感觉好像最坏情况下是 $O(N×log⁡N)$ 吧,因为之前的 while 需要$ log⁡N$的时间,最后要 $O(N)$ 的时间向左右子树递归:

return 1 + countNodes(root.left) + countNodes(root.right);

关键点在于,这两个递归只有一个会真的递归下去,另一个一定会触发 hl == hr 而立即返回,不会递归下去

为什么呢?原因如下:

一棵完全二叉树的两棵子树,至少有一棵是满二叉树

img

看图就明显了吧,由于完全二叉树的性质,其子树一定有一棵是满的,所以一定会触发 hl == hr,只消耗 $O*(log*N)$ 的复杂度而不会继续递归。

综上,算法的递归深度就是树的高度 $ O(logN) $,每次递归所花费的时间就是 while 循环,需要 $O(logN)$,所以总体的时间复杂度是$O(logN×logN)$。

所以说,「完全二叉树」这个概念还是有它存在的原因的,不仅适用于数组实现二叉堆,而且连计算节点总数这种看起来简单的操作都有高效的算法实现。

拓展:惰性展开多叉树

一、题目描述

这是力扣第 341 题「扁平化嵌套列表迭代器」:

341. 扁平化嵌套列表迭代器 | 力扣 | LeetCode |  🟠

给你一个嵌套的整数列表 nestedList 。每个元素要么是一个整数,要么是一个列表;该列表的元素也可能是整数或者是其他列表。请你实现一个迭代器将其扁平化,使之能够遍历这个列表中的所有整数。

实现扁平迭代器类 NestedIterator 

  • NestedIterator(List<NestedInteger> nestedList) 用嵌套列表 nestedList 初始化迭代器。
  • int next() 返回嵌套列表的下一个整数。
  • boolean hasNext() 如果仍然存在待迭代的整数,返回 true ;否则,返回 false 

你的代码将会用下述伪代码检测:

initialize iterator with nestedList
res = []
while iterator.hasNext()
    append iterator.next() to the end of res
return res

如果 res 与预期的扁平化列表匹配,那么你的代码将会被判为正确。

示例 1:

输入:nestedList = [[1,1],2,[1,1]]
输出:[1,1,2,1,1]
解释:通过重复调用 next 直到 hasNext 返回 false,next 返回的元素的顺序应该是: [1,1,2,1,1]

示例 2:

输入:nestedList = [1,[4,[6]]]
输出:[1,4,6]
解释:通过重复调用 next 直到 hasNext 返回 false,next 返回的元素的顺序应该是: [1,4,6]

提示:

  • 1 <= nestedList.length <= 500
  • 嵌套列表中的整数值在范围 [-106, 106] 
题目来源:力扣 341. 扁平化嵌套列表迭代器

我们的算法会被输入一个 NestedInteger 列表,我们需要做的就是写一个迭代器类 NestedIterator,将这个带有嵌套结构 NestedInteger 的列表「拍平」:

class NestedIterator:
    def __init__(self, nestedList: List[NestedInteger]):
        # 构造器输入一个 NestedInteger 列表
        pass

    # 返回下一个整数
    def next(self) -> int:
        pass

    # 是否还有下一个元素?
    def hasNext(self) -> bool:
        pass

我们写的这个 NestedIterator 类会被这样调用,先调用 hasNext 方法,后调用 next 方法

NestedIterator i = new NestedIterator(nestedList);
while (i.hasNext())
    print(i.next());

学过设计模式的朋友应该知道,迭代器也是设计模式的一种,目的就是为调用者屏蔽底层数据结构的细节,简单地通过 hasNextnext 方法有序地进行遍历。

为什么说这个题目很有启发性呢?因为我最近在用一款类似印象笔记的软件,叫做 Notion(挺有名的)。这个软件的一个亮点就是「万物皆 block」,比如说标题、页面、表格都是 block。有的 block 甚至可以无限嵌套,这就打破了传统笔记本「文件夹」->「笔记本」->「笔记」的三层结构。

回想这个算法问题,NestedInteger 结构实际上也是一种支持无限嵌套的结构,而且可以同时表示整数和列表两种不同类型,我想 Notion 的核心数据结构 block 估计也是这样的一种设计思路。

那么话说回来,对于这个算法问题,我们怎么解决呢?NestedInteger 结构可以无限嵌套,怎么把这个结构「打平」,为迭代器的调用者屏蔽底层细节,得到扁平化的输出呢?

二、解题思路

显然,NestedInteger 这个神奇的数据结构是问题的关键,不过题目专门提醒我们不要尝试去实现它,也不要去猜测它的实现。

为什么?凭什么?是不是题目在误导我?是不是我进行推测之后,这道题就不攻自破了

你不让推测,我就偏偏要去推测!我反手就把 NestedInteger 这个结构给实现出来:

# 定义嵌套整型类
class NestedInteger:
    # 嵌套整型类型,通过一个整数初始化
    def __init__(self, val: int = None, lst: List['NestedInteger'] = None):
        # 当前嵌套整型类的整型值
        self.val = val
        # 当前嵌套整型类嵌套的整型类列表
        self.lst = lst

    # 如果其中存的是一个整数,则返回 true,否则返回 false
    def isInteger(self) -> bool:
        return self.val is not None

    # 如果其中存的是一个整数,则返回这个整数,否则返回 null
    def getInteger(self) -> int:
        return self.val

    # 如果其中存的是一个列表,则返回这个列表,否则返回 null
    def getList(self) -> List['NestedInteger']:
        return self.lst

嗯,其实这个实现也不难嘛,写出来之后,我不禁翻出前文 多叉树基础及遍历,发现这玩意儿竟然……

class NestedInteger:
    def __init__(self):
        self.val = None
        self.list = []

# 基本的 N 叉树节点
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.children = []

这玩意儿不就是棵 N 叉树吗?叶子节点是 Integer 类型,其 val 字段非空;其他节点都是 List<NestedInteger> 类型,其 val 字段为空,但是 list 字段非空,装着孩子节点

比如说输入是 [[1,1],2,[1,1]],其实就是如下树状结构:

img

好的,刚才题目说什么来着?把一个 NestedInteger 扁平化对吧?这不就等价于遍历一棵 N 叉树的所有「叶子节点」吗?我把所有叶子节点都拿出来,不就可以作为迭代器进行遍历了吗?

N 叉树的遍历怎么整?我又不禁翻出前文 多叉树遍历框架

def traverse(root: TreeNode):
    if root is None:
        return
    for child in root.children:
        traverse(child)

这个框架可以遍历所有节点,而我们只对整数型的 NestedInteger 感兴趣,也就是我们只想要「叶子节点」,所以 traverse 函数只要在到达叶子节点的时候把 val 加入结果列表即可:

class NestedIterator:
    def __init__(self, nestedList: List[NestedInteger]):
        # 存放将 nestedList 打平的结果
        result = []
        for node in nestedList:
            # 以每个节点为根遍历
            self.traverse(node, result)
        self.index = 0
        self.result = result

    def traverse(self, root, result):
        if root.isInteger():
            # 到达叶子节点
            result.append(root.getInteger())
            return
        # 遍历框架
        for child in root.getList():
            self.traverse(child, result)

    def next(self) -> int:
        res = self.result[self.index]
        self.index += 1
        return res

    def hasNext(self) -> bool:
        return self.index < len(self.result)

这样,我们就把原问题巧妙转化成了一个 N 叉树的遍历问题,并且得到了解法。

三、进阶思路

以上解法虽然可以通过,但是在面试中,也许是有瑕疵的。

我们的解法中,一次性算出了所有叶子节点的值,全部装到 result 列表,也就是内存中,nexthasNext 方法只是在对 result 列表做迭代。如果输入的规模非常大,构造函数中的计算就会很慢,而且很占用内存。

一般的迭代器求值应该是「惰性的」,也就是说,如果你要一个结果,我就算一个(或是一小部分)结果出来,而不是一次把所有结果都算出来。

如果想做到这一点,使用递归函数进行 DFS 遍历肯定是不行的,而且我们其实只关心「叶子节点」,所以传统的 BFS 算法也不行。实际的思路很简单:

调用 hasNext 时,如果 nestedList 的第一个元素是列表类型,则不断展开这个元素,直到第一个元素是整数类型

仔细想一下这个过程应该就能理解了,一次只展开一个最内层的 nestedList,不会一次性把所有 nestedList 展开,相当于惰性的 DFS 遍历。

由于调用 next 方法之前一定会调用 hasNext 方法,这就可以保证每次调用 next 方法的时候第一个元素是整数型,直接返回并删除第一个元素即可。

看一下代码:

class NestedIterator:
    def __init__(self, nestedList: [NestedInteger]):
        # 不直接用 nestedList 的引用,是因为不能确定它的底层实现
        # 必须保证是 LinkedList,否则下面的 addFirst 会很低效
        self.list = collections.deque(nestedList)

    def next(self) -> int:
        # hasNext 方法保证了第一个元素一定是整数类型
        return self.list.popleft().getInteger()

    def hasNext(self) -> bool:
        # 循环拆分列表元素,直到列表第一个元素是整数类型
        while self.list and not self.list[0].isInteger():
            # 当列表开头第一个元素是列表类型时,进入循环
            first = self.list.popleft().getList()
            # 将第一个列表打平并按顺序添加到开头
            for i in range(len(first) - 1, -1, -1):
                self.list.appendleft(first[i])
        return bool(self.list)

以这种方法,符合迭代器惰性求值的特性,是比较好的解法。

拓展:归并排序详解及应用

一直都有很多读者说,想让我用框架思维讲一讲基本的排序算法,我觉得确实得讲讲,毕竟学习任何东西都讲求一个融会贯通,只有对其本质进行比较深刻的理解,才能运用自如。

本文就先讲归并排序,给一套代码模板,然后讲讲它在算法问题中的应用。阅读本文前我希望你读过前文 手把手刷二叉树(纲领篇)

我在讲二叉树的时候,提了一嘴归并排序,说归并排序就是二叉树的后序遍历,当时就有很多读者留言说醍醐灌顶。

知道为什么很多读者遇到递归相关的算法就觉得烧脑吗?因为还处在「看山是山,看水是水」的阶段。

就说归并排序吧,如果给你看代码,让你脑补一下归并排序的过程,你脑子里会出现什么场景?

这是一个数组排序算法,所以你脑补一个数组的 GIF,在那一个个交换元素?如果是这样的话,那格局就低了。

但如果你脑海中浮现出的是一棵二叉树,甚至浮现出二叉树后序遍历的场景,那格局就高了,大概率掌握了我经常强调的 框架思维,用这种抽象能力学习算法就省劲多了。

那么,归并排序明明就是一个数组算法,和二叉树有什么关系?接下来我就具体讲讲。


算法思路

就这么说吧,所有递归的算法,你甭管它是干什么的,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码,你要写递归算法,本质上就是要告诉每个节点需要做什么

你看归并排序的代码框架:

# 定义:排序 nums[lo..hi]
def sort(nums: List[int], lo: int, hi: int) -> None:
    if lo == hi:
        return
    mid = (lo + hi) // 2
    # 利用定义,排序 nums[lo..mid]
    sort(nums, lo, mid)
    # 利用定义,排序 nums[mid+1..hi]
    sort(nums, mid + 1, hi)

    # ***** 后序位置 *****
    # 此时两部分子数组已经被排好序
    # 合并两个有序数组,使 nums[lo..hi] 有序
    merge(nums, lo, mid, hi)

# 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
# 合并为有序数组 nums[lo..hi]
def merge(nums: List[int], lo: int, mid: int, hi: int) -> None:
    pass

看这个框架,也就明白那句经典的总结:归并排序就是先把左半边数组排好序,再把右半边数组排好序,然后把两半数组合并。

上述代码和二叉树的后序遍历很像:

# 二叉树遍历框架 
def traverse(root: TreeNode) -> None: 
    if root is None:
        return
    traverse(root.left)
    traverse(root.right)
    # 后序位置
    print(root.val)

再进一步,你联想一下求二叉树的最大深度的算法代码:

# 定义:输入根节点,返回这棵二叉树的最大深度
def maxDepth(root: TreeNode) -> int:
	if not root:
		return 0
	# 利用定义,计算左右子树的最大深度
	leftMax = maxDepth(root.left)
	rightMax = maxDepth(root.right)
	# 整棵树的最大深度等于左右子树的最大深度取最大值,
	# 然后再加上根节点自己
	res = max(leftMax, rightMax) + 1

	return res

是不是更像了?

前文 手把手刷二叉树(纲领篇) 说二叉树问题可以分为两类思路,一类是遍历一遍二叉树的思路,另一类是分解问题的思路,根据上述类比,显然归并排序利用的是分解问题的思路(分治算法)。

归并排序的过程可以在逻辑上抽象成一棵二叉树,树上的每个节点的值可以认为是 nums[lo..hi],叶子节点的值就是数组中的单个元素

img

然后,在每个节点的后序位置(左右子节点已经被排好序)的时候执行 merge 函数,合并两个子节点上的子数组:

img

这个 merge 操作会在二叉树的每个节点上都执行一遍,执行顺序是二叉树后序遍历的顺序。

后序遍历二叉树大家应该已经烂熟于心了,就是下图这个遍历顺序:

img

结合上述基本分析,我们把 nums[lo..hi] 理解成二叉树的节点,sort 函数理解成二叉树的遍历函数,整个归并排序的执行过程就是以下 GIF 描述的这样:

img

这样,归并排序的核心思路就分析完了,接下来只要把思路翻译成代码就行。

代码实现

只要拥有了正确的思维方式,理解算法思路是不困难的,但把思路实现成代码,也很考验一个人的编程能力

毕竟算法的时间复杂度只是一个理论上的衡量标准,而算法的实际运行效率要考虑的因素更多,比如应该避免内存的频繁分配释放,代码逻辑应尽可能简洁等等。

这里我参考《算法 4》这本书中归并排序代码给出归并排序的代码实现:

class Merge:

    # 用于辅助合并有序数组
    temp = []

    @staticmethod
    def sort(nums):
        # 先给辅助数组开辟内存空间
        Merge.temp = [0] * len(nums)
        # 排序整个数组(原地修改)
        Merge._sort(nums, 0, len(nums) - 1)

    # 定义:将子数组 nums[lo..hi] 进行排序
    @staticmethod
    def _sort(nums, lo, hi):
        if lo == hi:
            # 单个元素不用排序
            return
        # 这样写是为了防止溢出,效果等同于 (hi + lo) / 2
        mid = lo + (hi - lo) // 2
        # 先对左半部分数组 nums[lo..mid] 排序
        Merge._sort(nums, lo, mid)
        # 再对右半部分数组 nums[mid+1..hi] 排序
        Merge._sort(nums, mid + 1, hi)
        # 将两部分有序数组合并成一个有序数组
        Merge.merge(nums, lo, mid, hi)

    # 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
    @staticmethod
    def merge(nums, lo, mid, hi):
        # 先把 nums[lo..hi] 复制到辅助数组中
        # 以便合并后的结果能够直接存入 nums
        for i in range(lo, hi+1):
            Merge.temp[i] = nums[i]

        # 数组双指针技巧,合并两个有序数组
        i, j = lo, mid + 1
        for p in range(lo, hi+1):
            if i == mid + 1:
                # 左半边数组已全部被合并
                nums[p] = Merge.temp[j]
                j += 1
            elif j == hi + 1:
                # 右半边数组已全部被合并
                nums[p] = Merge.temp[i]
                i += 1
            elif Merge.temp[i] > Merge.temp[j]:
                nums[p] = Merge.temp[j]
                j += 1
            else:
                nums[p] = Merge.temp[i]
                i += 1

有了之前的铺垫,这里只需要着重讲一下这个 merge 函数。

sort 函数对 nums[lo..mid]nums[mid+1..hi] 递归排序完成之后,我们没有办法原地把它俩合并,所以需要 copy 到 temp 数组里面,然后通过类似于前文 单链表的六大技巧 中合并有序链表的双指针技巧将 nums[lo..hi] 合并成一个有序数组:

img

注意我们不是在 merge 函数执行的时候 new 辅助数组,而是提前把 temp 辅助数组 new 出来了,这样就避免了在递归中频繁分配和释放内存可能产生的性能问题。

贴一个归并排序过程的可视化动画,方便大家理解算法运行的过程:

 算法可视化面板

复杂度分析

再说一下归并排序的时间复杂度,虽然大伙儿应该都知道是 O(NlogN),但不见得所有人都知道这个复杂度怎么算出来的。

前文 动态规划详解 说过递归算法的复杂度计算,就是子问题个数 x 解决一个子问题的复杂度。对于归并排序来说,时间复杂度显然集中在 merge 函数遍历 nums[lo..hi] 的过程,但每次 merge 输入的 lohi 都不同,所以不容易直观地看出时间复杂度。

merge 函数到底执行了多少次?每次执行的时间复杂度是多少?总的时间复杂度是多少?这就要结合之前画的这幅图来看:

img

执行的次数是二叉树节点的个数,每次执行的复杂度就是每个节点代表的子数组的长度,所以总的时间复杂度就是整棵树中「数组元素」的个数

所以从整体上看,这个二叉树的高度是 logN,其中每一层的元素个数就是原数组的长度 N,所以总的时间复杂度就是 O(NlogN)。

912. 排序数组

力扣第 912 题「排序数组」就是让你对数组进行排序,我们可以直接套用归并排序代码模板:

class Solution:
    def sortArray(self, nums: List[int]) -> List[int]:
        Merge.sort(nums)
        return nums

class Merge:
    # 见上文

315. 计算右侧小于当前元素的个数

除了最基本的排序问题,归并排序还可以用来解决力扣第 315 题「计算右侧小于当前元素的个数」:

315. 计算右侧小于当前元素的个数 | 力扣 | LeetCode |  🔴

给你一个整数数组 nums ,按要求返回一个新数组 counts 。数组 counts 有该性质: counts[i] 的值是  nums[i] 右侧小于 nums[i] 的元素的数量。

示例 1:

输入:nums = [5,2,6,1]
输出:[2,1,1,0] 
解释:
5 的右侧有 2 个更小的元素 (2 和 1)
2 的右侧仅有 1 个更小的元素 (1)
6 的右侧有 1 个更小的元素 (1)
1 的右侧有 0 个更小的元素

示例 2:

输入:nums = [-1]
输出:[0]

示例 3:

输入:nums = [-1,-1]
输出:[0,0]

提示:

  • 1 <= nums.length <= 105
  • -104 <= nums[i] <= 104
题目来源:力扣 315. 计算右侧小于当前元素的个数

我用比较数学的语言来描述一下(方便和后续类似题目进行对比),题目让你求出一个 count 数组,使得:

count[i] = COUNT(j) where j > i and nums[j] < nums[i]

拍脑袋的暴力解法就不说了,嵌套 for 循环,平方级别的复杂度。

这题和归并排序什么关系呢,主要在 merge 函数,我们在使用 merge 函数合并两个有序数组的时候,其实是可以知道一个元素 nums[i] 后边有多少个元素比 nums[i] 小的

具体来说,比如这个场景:

img

这时候我们应该把 temp[i] 放到 nums[p] 上,因为 temp[i] < temp[j]

但就在这个场景下,我们还可以知道一个信息:5 后面比 5 小的元素个数就是 左闭右开区间 [mid + 1, j) 中的元素个数,即 2 和 4 这两个元素:

img

换句话说,在对 nums[lo..hi] 合并的过程中,每当执行 nums[p] = temp[i] 时,就可以确定 temp[i] 这个元素后面比它小的元素个数为 j - mid - 1

当然,nums[lo..hi] 本身也只是一个子数组,这个子数组之后还会被执行 merge,其中元素的位置还是会改变。但这是其他递归节点需要考虑的问题,我们只要在 merge 函数中做一些手脚,叠加每次 merge 时记录的结果即可。

发现了这个规律后,我们只要在 merge 中添加两行代码即可解决这个问题,看解法代码:

class Solution:
    class Pair:
        def __init__(self, val, id):
            # 记录数组的元素值
            self.val = val
            # 记录元素在数组中的原始索引
            self.id = id

    # 归并排序所用的辅助数组
    temp = []
    # 记录每个元素后面比自己小的元素个数
    count = []

    # 主函数
    def countSmaller(self, nums):
        n = len(nums)
        self.count = [0] * n
        self.temp = [self.Pair(0, 0)] * n
        arr = []
        # 记录元素原始的索引位置,以便在 count 数组中更新结果
        for i in range(n):
            arr.append(self.Pair(nums[i], i))

        # 执行归并排序,本题结果被记录在 count 数组中
        self.sort(arr, 0, n - 1)

        res = []
        for c in self.count:
            res.append(c)
        return res

    # 归并排序
    def sort(self, arr, lo, hi):
        if lo == hi:
            return
        mid = lo + (hi - lo) // 2
        self.sort(arr, lo, mid)
        self.sort(arr, mid + 1, hi)
        self.merge(arr, lo, mid, hi)

    # 合并两个有序数组
    def merge(self, arr, lo, mid, hi):
        for i in range(lo, hi + 1):
            self.temp[i] = arr[i]

        i, j = lo, mid + 1
        for p in range(lo, hi + 1):
            if i == mid + 1:
                arr[p] = self.temp[j]
                j += 1
            elif j == hi + 1:
                arr[p] = self.temp[i]
                i += 1
                # 更新 count 数组
                self.count[arr[p].id] += j - mid - 1
            elif self.temp[i].val > self.temp[j].val:
                arr[p] = self.temp[j]
                j += 1
            else:
                arr[p] = self.temp[i]
                i += 1
                # 更新 count 数组
                self.count[arr[p].id] += j - mid - 1

因为在排序过程中,每个元素的索引位置会不断改变,所以我们用一个 Pair 类封装每个元素及其在原始数组 nums 中的索引,以便 count 数组记录每个元素之后小于它的元素个数。

接下来我们再看几道原理类似的题目,都是通过给归并排序的 merge 函数加一些私货完成目标。

493. 翻转对

看一下力扣第 493 题「翻转对」:

493. 翻转对 | 力扣 | LeetCode |  🔴

给定一个数组 nums ,如果 i < j 且 nums[i] > 2*nums[j] 我们就将 (i, j) 称作一个重要翻转对

你需要返回给定数组中的重要翻转对的数量。

示例 1:

输入: [1,3,2,3,1]
输出: 2

示例 2:

输入: [2,4,3,5,1]
输出: 3

注意:

  1. 给定数组的长度不会超过50000
  2. 输入数组中的所有数字都在32位整数的表示范围内。
题目来源:力扣 493. 翻转对

我把这道题换个表述方式,你注意和上一道题目对比:

请你先求出一个 count 数组,其中:

count[i] = COUNT(j) where j > i and nums[i] > 2*nums[j]

然后请你求出这个 count 数组中所有元素的和。

你看,这样说其实和题目是一个意思,而且和上一道题非常类似,只不过上一题求的是 nums[i] > nums[j],这里求的是 nums[i] > 2*nums[j] 罢了。

所以解题的思路当然还是要在 merge 函数中做点手脚,当 nums[lo..mid]nums[mid+1..hi] 两个子数组完成排序后,对于 nums[lo..mid] 中的每个元素 nums[i],去 nums[mid+1..hi] 中寻找符合条件的 nums[j] 就行了。

看一下我们对上一题 merge 函数的改造:

# 记录「翻转对」的个数
count = 0

# 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
def merge(nums, lo, mid, hi):
    global count
    temp = nums.copy()

    # 在合并有序数组之前,加点私货
    for i in range(lo, mid + 1):
        # 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
        for j in range(mid + 1, hi + 1):
            # nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
            if int(nums[i]) > 2 * int(nums[j]):
                count += 1

    # 数组双指针技巧,合并两个有序数组
    i, j = lo, mid + 1
    for p in range(lo, hi + 1):
        if (i == mid + 1):
            nums[p] = temp[j]
            j += 1
        elif (j == hi + 1):
            nums[p] = temp[i]
            i += 1
        elif temp[i] > temp[j]:
            nums[p] = temp[j]
            j += 1
        else:
            nums[p] = temp[i]
            i += 1

不过呢,这样修改代码会超时,毕竟额外添加了一个嵌套 for 循环。怎么进行优化呢,注意子数组 nums[lo..mid] 是排好序的,也就是 nums[i] <= nums[i+1]

所以,对于 nums[i], lo <= i <= mid,我们在找到的符合 nums[i] > 2*nums[j]nums[j], mid+1 <= j <= hi,也必然也符合 nums[i+1] > 2*nums[j]

换句话说,我们不用每次都傻乎乎地去遍历整个 nums[mid+1..hi],只要维护一个开区间边界 end,维护 nums[mid+1..end-1] 是符合条件的元素即可

看最终的解法代码:

class Solution:
    def __init__(self):
        # 记录「翻转对」的个数
        self.count = 0
        self.temp = []

    def reversePairs(self, nums) -> int:
        # 执行归并排序
        self.temp = [0]*len(nums)
        self.sort(nums, 0, len(nums) - 1)
        return self.count

    def sort(self, nums, lo, hi):
        # 归并排序
        if lo >= hi:
            return

        mid = lo + (hi - lo) // 2
        self.sort(nums, lo, mid)
        self.sort(nums, mid + 1, hi)
        self.merge(nums, lo, mid, hi)

    def merge(self, nums, lo, mid, hi):
        for i in range(lo, hi+1):
            self.temp[i] = nums[i]
        
        # 维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
        # 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
        end = mid + 1
        for i in range(lo, mid+1):
            # nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
            while end <= hi and nums[i] > nums[end] * 2:
                end += 1

            self.count += end - (mid + 1)

        # 数组双指针技巧,合并两个有序数组
        i, j = lo, mid + 1
        for p in range(lo, hi+1):
            if i == mid + 1:
                nums[p] = self.temp[j]
                j += 1
            elif j == hi + 1:
                nums[p] = self.temp[i]
                i += 1
            elif self.temp[i] > self.temp[j]:
                nums[p] = self.temp[j]
                j += 1
            else:
                nums[p] = self.temp[i]
                i += 1

327. 区间和的个数

如果你能够理解这道题目,我们最后来看一道难度更大的题目,力扣第 327 题「区间和的个数」:

327. 区间和的个数 | 力扣 | LeetCode |  🔴

给你一个整数数组 nums 以及两个整数 lower  upper 。求数组中,值位于范围 [lower, upper] (包含 lower  upper)之内的 区间和的个数 

区间和 S(i, j) 表示在 nums 中,位置从 i  j 的元素之和,包含 i  j (i  j)。

示例 1:
输入:nums = [-2,5,-1], lower = -2, upper = 2
输出:3
解释:存在三个区间:[0,0]、[2,2] 和 [0,2] ,对应的区间和分别是:-2 、-1 、2 。

示例 2:

输入:nums = [0], lower = 0, upper = 0
输出:1

提示:

  • 1 <= nums.length <= 105
  • -231 <= nums[i] <= 231 - 1
  • -105 <= lower <= upper <= 105
  • 题目数据保证答案是一个 32 位 的整数
题目来源:力扣 327. 区间和的个数

简单说,题目让你计算元素和落在 [lower, upper] 中的所有子数组的个数。

拍脑袋的暴力解法我就不说了,依然是嵌套 for 循环,这里还是说利用归并排序实现的高效算法。

首先,解决这道题需要快速计算子数组的和,所以你需要阅读前文 前缀和数组技巧,创建一个前缀和数组 preSum 来辅助我们迅速计算区间和。

我继续用比较数学的语言来表述下这道题,题目让你通过 preSum 数组求一个 count 数组,使得:

count[i] = COUNT(j) where lower <= preSum[j] - preSum[i] <= upper

然后请你求出这个 count 数组中所有元素的和。

你看,这是不是和题目描述一样?preSum 中的两个元素之差其实就是区间和。

有了之前两道题的铺垫,我直接给出这道题的解法代码吧,思路见注释:

class Solution:
    def __init__(self):
        # 定义实例变量
        self.lower = 0
        self.upper = 0
        self.count = 0
        self.temp = []

    def countRangeSum(self, nums, lower, upper):
        # 赋值实例变量
        self.lower = lower
        self.upper = upper
        
        # 构建前缀和数组,注意 int 可能溢出,用 long 存储
        preSum = [0] * (len(nums) + 1)
        for i in range(len(nums)):
            preSum[i + 1] = nums[i] + preSum[i]
            
        # 对前缀和数组进行归并排序
        self.temp = [0] * len(preSum)
        self.sort(preSum, 0, len(preSum) - 1)
        return self.count
        
        
    def sort(self, nums, lo, hi):
        if lo == hi:
            return
        mid = lo + (hi - lo) // 2
        self.sort(nums, lo, mid)
        self.sort(nums, mid + 1, hi)
        self.merge(nums, lo, mid, hi)

    def merge(self, nums, lo, mid, hi):
        # 填充 temp 数组
        for i in range(lo, hi + 1):
            self.temp[i] = nums[i]
        
        # 在合并有序数组之前加点私货(这段代码会超时)
        # for i in range(lo, mid + 1):
        #     for j in range(mid + 1, hi + 1):
        # 寻找符合条件的 nums[j]
        #
        #         delta = nums[j] - nums[i]
        #         if delta <= self.upper and delta >= self.lower:
        #             self.count += 1

        # 进行效率优化
        # 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
        start = mid + 1
        end = mid + 1
        for i in range(lo, mid + 1):
            # 如果 nums[i] 对应的区间是 [start, end),
            # 那么 nums[i+1] 对应的区间一定会整体右移,类似滑动窗口
            while start <= hi and nums[start] - nums[i] < self.lower:
                start += 1
            while end <= hi and nums[end] - nums[i] <= self.upper:
                end += 1
            self.count += end - start

        # 数组双指针技巧,合并两个有序数组
        i, j = lo, mid + 1
        for p in range(lo, hi + 1):
            if i == mid + 1:
                nums[p] = self.temp[j]
                j += 1
            elif j == hi + 1:
                nums[p] = self.temp[i]
                i += 1
            elif self.temp[i] > self.temp[j]:
                nums[p] = self.temp[j]
                j += 1
            else:
                nums[p] = self.temp[i]
                i += 1

我们依然在 merge 函数合并有序数组之前加了一些逻辑,如果看过前文 滑动窗口核心框架,这个效率优化有点类似维护一个滑动窗口,让窗口中的元素和 nums[i] 的差落在 [lower, upper] 中。

归并排序相关的题目到这里就讲完了,你现在回头体会下我在本文开头说那句话:

所有递归的算法,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码。你要写递归算法,本质上就是要告诉每个节点需要做什么

比如本文讲的归并排序算法,递归的 sort 函数就是二叉树的遍历函数,而 merge 函数就是在每个节点上做的事情,有没有品出点味道?

最后总结一下吧,本文从二叉树的角度讲了归并排序的核心思路和代码实现,同时讲了几道归并排序相关的算法题。这些算法题其实就是归并排序算法逻辑中夹杂一点私货,但仍然属于比较难的,你可能需要亲自做一遍才能理解。

那我最后留一个思考题吧,下一篇文章我会讲快速排序,你是否能够尝试着从二叉树的角度去理解快速排序?如果让你用一句话总结快速排序的逻辑,你怎么描述?

好了,答案在下篇文章 快速排序详解及应用 揭晓。

拓展:快速排序详解及应用

前文 归并排序算法详解 通过二叉树的视角描述了归并排序的算法原理以及应用,很多读者大呼精妙,那我就趁热打铁,今天继续用二叉树的视角讲一讲快速排序算法的原理以及运用

快速排序算法思路

首先我们看一下快速排序的代码框架:

def sort(nums: List[int], lo: int, hi: int):
    if lo >= hi:
        return
    # ****** 前序位置 ******
    # 对 nums[lo..hi] 进行切分,将 nums[p] 排好序
    # 使得 nums[lo..p-1] <= nums[p] < nums[p+1..hi]
    p = partition(nums, lo, hi)

    # 去左右子数组进行切分
    sort(nums, lo, p - 1)
    sort(nums, p + 1, hi)

其实你对比之后可以发现,快速排序就是一个二叉树的前序遍历:

# 二叉树遍历框架
def traverse(root: TreeNode):
    if not root:
        return
    # 前序位置
    print(root.val)
    traverse(root.left)
    traverse(root.right)

另外,前文 归并排序详解 用一句话总结了归并排序:先把左半边数组排好序,再把右半边数组排好序,然后把两半数组合并。

同时我提了一个问题,让你一句话总结快速排序,这里说一下我的答案:

快速排序是先将一个元素排好序,然后再将剩下的元素排好序

为什么这么说呢,且听我慢慢道来。

快速排序的核心无疑是 partition 函数,partition 函数的作用是在 nums[lo..hi] 中寻找一个切分点 p,通过交换元素使得 nums[lo..p-1] 都小于等于 nums[p],且 nums[p+1..hi] 都大于 nums[p]

img

一个元素左边的元素都比它小,右边的元素都比它大,啥意思?不就是它自己已经被放到正确的位置上了吗?

所以 partition 函数干的事情,其实就是把 nums[p] 这个元素排好序了。

一个元素被排好序了,然后呢?你再把剩下的元素排好序不就得了。

剩下的元素有哪些?左边一坨,右边一坨,去吧,对子数组进行递归,用 partition 函数把剩下的元素也排好序。

从二叉树的视角,我们可以把子数组 nums[lo..hi] 理解成二叉树节点上的值,sort 函数理解成二叉树的遍历函数

参照二叉树的前序遍历顺序,快速排序的运行过程如下 GIF:

img

你注意最后形成的这棵二叉树是什么?是一棵二叉搜索树:

img

这应该不难理解吧,因为 partition 函数每次都将数组切分成左小右大两部分,恰好和二叉搜索树左小右大的特性吻合。

你甚至可以这样理解:快速排序的过程是一个构造二叉搜索树的过程

但谈到二叉搜索树的构造,那就不得不说二叉搜索树不平衡的极端情况,极端情况下二叉搜索树会退化成一个链表,导致操作效率大幅降低。

快速排序的过程中也有类似的情况,比如我画的图中每次 partition 函数选出的切分点都能把 nums[lo..hi] 平分成两半,但现实中你不见得运气这么好。

如果你每次运气都特别背,有一边的元素特别少的话,这样会导致二叉树生长不平衡:

img

这样的话,时间复杂度会大幅上升,后面分析时间复杂度的时候再细说。

我们为了避免出现这种极端情况,需要引入随机性

常见的方式是在进行排序之前对整个数组执行 洗牌算法 进行打乱,或者在 partition 函数中随机选择数组元素作为切分点,本文会使用前者。

快速排序代码实现

明白了上述概念,直接看快速排序的代码实现:

import random

class Quick:
    @staticmethod
    def sort(nums: List[int]):
        # 为了避免出现耗时的极端情况,先随机打乱
        random.shuffle(nums)
        # 排序整个数组(原地修改)
        Quick.sort_(nums, 0, len(nums) - 1)

    @staticmethod
    def sort_(nums: List[int], lo: int, hi: int):
        if lo >= hi:
            return
        # 对 nums[lo..hi] 进行切分
        # 使得 nums[lo..p-1] <= nums[p] < nums[p+1..hi]
        p = Quick.partition(nums, lo, hi)

        Quick.sort_(nums, lo, p - 1)
        Quick.sort_(nums, p + 1, hi)
    
    # 对 nums[lo..hi] 进行切分
    @staticmethod
    def partition(nums: List[int], lo: int, hi: int) -> int:
        pivot = nums[lo]
        # 关于区间的边界控制需格外小心,稍有不慎就会出错
        # 我这里把 i, j 定义为开区间,同时定义:
        # [lo, i) <= pivot;(j, hi] > pivot
        # 之后都要正确维护这个边界区间的定义
        i, j = lo + 1, hi
        # 当 i > j 时结束循环,以保证区间 [lo, hi] 都被覆盖
        while i <= j:
            while i < hi and nums[i] <= pivot:
                i += 1
                # 此 while 结束时恰好 nums[i] > pivot
            while j > lo and nums[j] > pivot:
                j -= 1
                # 此 while 结束时恰好 nums[j] <= pivot

            if i >= j:
                break
            # 此时 [lo, i) <= pivot && (j, hi] > pivot
            # 交换 nums[j] 和 nums[i]
            nums[i], nums[j] = nums[j], nums[i]
            # 此时 [lo, i] <= pivot && [j, hi] > pivot
        # 最后将 pivot 放到合适的位置,即 pivot 左边元素较小,右边元素较大
        nums[lo], nums[j] = nums[j], nums[lo]
        return j

这里啰嗦一下核心函数 partition 的实现,正如前文 二分搜索框架详解 所说,想要正确寻找切分点非常考验你对边界条件的控制,稍有差错就会产生错误的结果。

处理边界细节的一个技巧就是,你要明确每个变量的定义以及区间的开闭情况。具体的细节看代码注释,建议自己动手实践。

贴一个快速排序过程的可视化动画,方便大家理解算法运行的过程:

 算法可视化面板

复杂度分析及其他要点

接下来分析一下快速排序的时间复杂度。

显然,快速排序的时间复杂度主要消耗在 partition 函数上,因为这个函数中存在循环。

所以 partition 函数到底执行了多少次?每次执行的时间复杂度是多少?总的时间复杂度是多少?

和归并排序类似,需要结合之前画的这幅图来从整体上分析:

img

partition 执行的次数是二叉树节点的个数,每次执行的复杂度就是每个节点代表的子数组 nums[lo..hi] 的长度,所以总的时间复杂度就是整棵树中「数组元素」的个数

假设数组元素个数为 N,那么二叉树每一层的元素个数之和就是 O(N);切分点 p 每次都落在数组正中间的理想情况下,树的层数为 O(logN),所以理想的总时间复杂度为 O(NlogN)。

由于快速排序没有使用任何辅助数组,所以空间复杂度就是递归堆栈的深度,也就是树高 O(logN)。

当然,我们之前说过快速排序的效率存在一定随机性,如果每次 partition 切分的结果都极不均匀:

img

快速排序就退化成选择排序了,树高为 O(N),每层节点的元素个数从 N 开始递减,总的时间复杂度为:

N + (N - 1) + (N - 2) + ... + 1 = O(N^2)

所以我们说,快速排序理想情况的时间复杂度是 O(NlogN),空间复杂度 O(logN),极端情况下的最坏时间复杂度是 O(N2),空间复杂度是 O(N)。

不过大家放心,经过随机化的 partition 函数很难出现极端情况,所以快速排序的效率还是非常高的。

还有一点需要注意的是,快速排序是「不稳定排序」,与之相对的,前文讲的 归并排序 是「稳定排序」

对于序列中的相同元素,如果排序之后它们的相对位置没有发生改变,则称该排序算法为「稳定排序」,反之则为「不稳定排序」。

如果单单排序 int 数组,那么稳定性没有什么意义。但如果排序一些结构比较复杂的数据,那么稳定排序就有更大的优势了。

比如说你有若干订单数据,已经按照订单号排好序了,现在你想对订单的交易日期再进行排序:

如果用稳定排序算法(比如归并排序),那么这些订单不仅按照交易日期排好了序,而且相同交易日期的订单的订单号依然是有序的。

但如果你用不稳定排序算法(比如快速排序),那么虽然排序结果会按照交易日期排好序,但相同交易日期的订单的订单号会丧失有序性。

在实际工程中我们经常会将一个复杂对象的某一个字段作为排序的 key,所以应该关注编程语言提供的 API 底层使用的到底是什么排序算法,是稳定的还是不稳定的,这很可能影响到代码执行的效率甚至正确性

说了这么多,快速排序算法应该算是讲明白了,力扣第 912 题「排序数组」就是让你对数组进行排序,我们可以直接套用快速排序的代码模板:

class Solution:
    def sortArray(self, nums: List[int]) -> List[int]:
        # 归并排序对数组进行原地排序
        Quick.sort(nums)
        return nums

class Quick:
    # 见上文

快速选择算法

快速排序算法还有一些有趣的变体,比如快速选择算法(Quick Select),主要场景是寻找第 kk 大的元素或者求中位数(k=N/2k=N/2)。

力扣第 215 题「数组中的第 K 个最大元素」就是一道类似的题目,函数签名如下:

def findKthLargest(nums: List[int], k: int) -> int:

题目要求我们寻找k 个最大的元素,稍微有点绕,意思是去寻找 nums 数组降序排列后排名第 k 的那个元素。

比如输入 nums = [2,1,5,4], k = 2,算法应该返回 4,因为 4 是 nums 中第 2 个最大的元素。

这种问题有两种解法,一种是 二叉堆(优先队列) 的解法,另一种就是快速选择算法,我们分别来看。

二叉堆的解法比较简单,但时间复杂度稍高,直接看代码好了:

import heapq
class Solution:
    def findKthLargest(self, nums, k):
        # 小顶堆,堆顶是最小元素
        pq = []
        for e in nums:
            # 每个元素都要过一遍二叉堆
            heapq.heappush(pq, e)
            # 堆中元素多于 k 个时,删除堆顶元素
            if len(pq) > k:
                heapq.heappop(pq)
        # pq 中剩下的是 nums 中 k 个最大元素,
        # 堆顶是最小的那个,即第 k 个最大元素
        return pq[0]

二叉堆(优先队列)是一种能够自动排序的数据结构,我们前文 手把手实现二叉堆数据结构 实现过这种结构,我就默认大家熟悉它的特性了。

核心思路就是把小顶堆 pq 理解成一个筛子,较大的元素会沉淀下去,较小的元素会浮上来;当堆大小超过 k 的时候,我们就删掉堆顶的元素,因为这些元素比较小,而我们想要的是前 k 个最大元素嘛。

nums 中的所有元素都过了一遍之后,筛子里面留下的就是最大的 k 个元素,而堆顶元素是堆中最小的元素,也就是「第 k 个最大的元素」。

思路很简单吧,唯一注意的是,Java 的 PriorityQueue 默认实现是小顶堆,有的语言的优先队列可能默认是大顶堆,可能需要做一些调整。

二叉堆插入和删除的时间复杂度和堆中的元素个数有关,在这里我们堆的大小不会超过 k,所以插入和删除元素的复杂度是 O(logk),再套一层 for 循环,假设数组元素总数为 N,总的时间复杂度就是 O(Nlogk)。

这个解法的空间复杂度很显然就是二叉堆的大小,为 O(k)。

快速选择算法是快速排序的变体,效率更高,面试中如果能够写出快速选择算法,肯定是加分项。

首先,题目问「第 k 个最大的元素」,相当于数组升序排序后「排名第 n - k 的元素」,为了方便表述,后文另 k' = n - k

如何知道「排名第 k' 的元素」呢?其实在快速排序算法 partition 函数执行的过程中就可以略见一二。

我们刚说了,partition 函数会将 nums[p] 排到正确的位置,使得 nums[lo..p-1] < nums[p] < nums[p+1..hi]

这时候,虽然还没有把整个数组排好序,但我们已经让 nums[p] 左边的元素都比 nums[p] 小了,也就知道 nums[p] 的排名了。

那么我们可以把 pk' 进行比较,如果 p < k' 说明第 k' 大的元素在 nums[p+1..hi] 中,如果 p > k' 说明第 k' 大的元素在 nums[lo..p-1]

进一步,去 nums[p+1..hi] 或者 nums[lo..p-1] 这两个子数组中执行 partition 函数,就可以进一步缩小排在第 k' 的元素的范围,最终找到目标元素。

这样就可以写出解法代码:

import random

class Solution:
    
    def findKthLargest(self, nums: List[int], k: int) -> int:
        # 首先随机打乱数组
        random.shuffle(nums)
        lo, hi = 0, len(nums) - 1
        # 转化成「排名第 k 的元素」
        k = len(nums) - k
        while lo <= hi:
            # 在 nums[lo..hi] 中选一个切分点
            p = self.partition(nums, lo, hi)
            if p < k:
                # 第 k 大的元素在 nums[p+1..hi] 中
                lo = p + 1
            elif p > k:
                # 第 k 大的元素在 nums[lo..p-1] 中
                hi = p - 1
            else:
                # 找到第 k 大元素
                return nums[p]
        return -1

    # 对 nums[lo..hi] 进行切分
    def partition(self, nums: List[int], lo: int, hi: int) -> int:
        # 见前文
        pass

这个代码框架其实非常像我们前文 二分搜索框架 的代码,这也是这个算法高效的原因,但是时间复杂度为什么是 O(N) 呢?

显然,这个算法的时间复杂度也主要集中在 partition 函数上,我们需要估算 partition 函数执行了多少次,每次执行的时间复杂度是多少。

最好情况下,每次 partition 函数切分出的 p 都恰好是正中间索引 (lo + hi) / 2(二分),且每次切分之后会到左边或者右边的子数组继续进行切分,那么 partition 函数执行的次数是 logN,每次输入的数组大小缩短一半。

所以总的时间复杂度为:

// 等比数列
N + N/2 + N/4 + N/8 + ... + 1 = 2N = O(N)

当然,类似快速排序,快速选择算法中的 partition 函数也可能出现极端情况,最坏情况下 p 一直都是 lo + 1 或者一直都是 hi - 1,这样的话时间复杂度就退化为 O(N2)了:

N + (N - 1) + (N - 2) + ... + 1 = O(N^2)

这也是我们在代码中使用 shuffle 函数的原因,通过引入随机性来避免极端情况的出现,让算法的效率保持在比较高的水平。随机化之后的快速选择算法的复杂度可以认为是 O(N)。

到这里,快速排序算法和快速选择算法就讲完了,从二叉树的视角来理解思路应该是不难的,但 partition 函数对细节的把控需要你多花心思去理解和记忆。

最后你可以比较一下快速排序和前文讲的 归并排序 并且可以说说你的理解:为什么快速排序是不稳定排序,而归并排序是稳定排序?


文章作者: Mealsee
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Mealsee !
  目录