[数解算法] 扔鸡蛋问题
2022-07-09
数学
算法
👋 ‍️‍️阅读
❤️ 喜欢
💬 评论

[数解算法] 扔鸡蛋问题

扔鸡蛋问题是一道经典的算法题,LeetCode链接

  • 一般做法是动态规划,时间复杂度O(KF2)O(KF^2),对应K个蛋F层楼。
  • 优化后可以达到O(KF)O(KF)

该题目下有很多算法详解,包括官方解答。

但是今天我们不聊动态规划,从数学的角度解决这个问题。

初步分析

kk个鸡蛋,ff层楼,求最小次数n(k,f)n(k,f)

易知,n(1,f)=fn(1,f) = f,即一个鸡蛋只能从一楼开始逐层尝试。

然而,当k=2k=2的时候,问题就难以通过简单的心算来得到答案。但是容易观察到:

  • n(2,2)=2n(2,2) = 2
  • n(2,3)=2n(2,3) = 2

即,2层楼和3层楼都需要扔2次。这一现象也很直观,不考虑蛋碎的情况,2次二分查找最多可以遍历3项。

同时,我们发现函数n(2,f)n(2,f)不是严格单调的,而是分段阶梯式,这非常不利于我们求解。

所以,我们不妨转化为问题f(k,n)f(k,n),即给定鸡蛋和次数,最多可以遍历的楼层数。如果该问题得解,我们很容易求解原问题。

最多楼层

现在我们不需要标记楼层号了,而只需要知道自己遍历了多少层。

若当前有kk蛋,nn次机会,则在一次丢蛋后

  • 剩余机会必然为n1n-1
  • 蛋可能为kk或者k1k-1,对应向上和向下继续遍历
  • 同时,当前楼层已经被遍历

所以我们有了下式

f(k,n)=1+f(k1,n1)+f(k,n1)\begin{align} f(k,n) &= 1+ f(k-1, n-1) + f(k, n-1) \end{align}

又由于1个蛋的情况我们已经得解,即

f(1,n)=n\begin{align} f(1,n) &= n \end{align}

(2)(2)带入(1)(1)中,可得

f(2,n)=1+f(1,n1)+f(2,n1)=1+n1+f(2,n1)=n+f(2,n1)\begin{aligned} f(2,n) &= 1+ f(1,n-1) + f(2, n-1) \\ &= 1 + n - 1 + f(2, n-1) \\ &= n + f(2,n-1) \end{aligned}

这不就变成了等差数列求和,所以我们有

f(2,n)=12n2+12n\begin{align} f(2,n) &= \frac 12 n^2 + \frac 12 n \end{align}

如果我们把(3)(3)也带入(1)(1)中,即可得f(3,n)f(3,n)的通项。

虽然我们可以求解,但这离我们的通解还有一些距离。

求解

虽然我们无法直接求解通项,但是易知

(我们把ff看做nn的多项式,而kk为常数)

  • f(k,n)f(k,n)kk次多项式
  • f(k,n)f(k,n)过点(i,2i1)),0<=i<=k(i,2^i-1)), 0<=i<=k

这么一来,我们就有了k+1k+1个方程求解kk次方程。

记该多项式ii次幂的系数为ak,ia_{k,i},即有:

Ak=[ak,0ak,1ak,k]TRk=[0132k1]TMk=(00010k10111kk0k1kk)MkAk=RkAk=Mk1Rk\begin{align} A_k & = \begin{bmatrix} a_{k,0} & a_{k,1} & \cdots & a_{k,k} \end{bmatrix} ^ T \\ R_k & = \begin{bmatrix} 0 & 1 & 3 & \cdots & 2^k - 1 \end{bmatrix} ^ T \\ M_k & = \begin{pmatrix} 0^0 & 0^1 & \cdots & 0^k \\ 1^0 & 1^1 & \cdots & 1^k \\ \vdots & \vdots & \ddots & \vdots \\ k^0 & k^1 & \cdots & k^k \end{pmatrix} \\ M_k \cdot A_k &= R_k \\ A_k &= M_k^{-1} R_k \end{align}

现在,我们就得到了解,对于kk个蛋,nn次机会,最多遍历楼层数为:

f(k,n)=[n0n1nk]Ak\begin{align} f(k,n) = \begin{bmatrix} n^0 & n^1 & \cdots & n^k \end{bmatrix} \cdot A_k \end{align}

该解法的时间复杂度为O(k3)O(k^3),即矩阵求逆。

但是考虑到,如果n<k或者f<2kn<k或者f<2^k时,可以直接求解,故

O(k3)<O(kln2f)O(k^3) < O(k \cdot ln^2f)

答题

至此,我们已经得到了题解,接下来我们把它转化为代码提交给LeetCode。

以下代码可以直接提交给LeetCode通过。

import numpy as np
import numpy.linalg as nl

class Solution:
    def superEggDrop(self, egg: int, floor: int) -> int:
        if 2 ** egg > floor:
            return int(np.ceil(np.log2(floor + 1)))
        M = np.array([[i ** j for j in range(egg + 1)] for i in range(egg + 1)])
        R = np.array([[2 ** i - 1] for i in range(egg + 1)])
        A = np.matmul(nl.inv(M), R)
        A[0] = -floor
        roots = np.roots(np.flip(A.T[0]))
        root = [e.real for e in roots if np.isreal(e) and e.real > 0][0].real
        return int(np.ceil(np.round(root, 6)))

其中M, R, A对应前式相同符号。 唯一需要注意的是,python求根有精度误差,所以在结果返回前进行了一次6位精度的round。

我的解答的运行时间为76ms,对比官方答案需要184ms,相差三倍。

后记

在写完本文后,我查看了LeetCode上别人的解法。受到这个答案的启发,才发现自己组合数学确实不好。

(1)(1)其实有非常简洁的形式。

我们已知组合恒等式

(nk)=(n1k1)+(n1k)\begin{align} {n \choose k} = {n-1 \choose k-1} + {n-1 \choose k} \end{align}

这与式(1)(1)非常相似。所以我们需要从中构造出一个多余的1。

将式(10)(10)k=kk=kk=1k=1展开,即

(nk)=(n1k1)+(n1k)(nk1)=(n1k2)+(n1k1)(n2)=(n11)+(n12)(n1)=(n10)+(n11)\begin{align} {n \choose k} &= {n-1 \choose k-1} + {n-1 \choose k} \\ {n \choose k-1} &= {n-1 \choose k-2} + {n-1 \choose k-1} \\ &\vdots \\ {n \choose 2} &= {n-1 \choose 1} + {n-1 \choose 2} \\ {n \choose 1} &= {n-1 \choose 0} + {n-1 \choose 1} \end{align}

我们将kk个式子相加,右边最后一个式子的第一项(n10){n-1 \choose 0}正好是我们需要的11,而去掉这一项后,此列正好剩下k1k-1,正好满足我们的形式,即

i=1k(ni)=1+i=1k1(n1i)+i=1k(n1i)\begin{align} \sum_{i=1}^k {n \choose i} &= 1+ \sum_{i=1}^{k-1} {n-1 \choose i} + \sum_{i=1}^k {n-1 \choose i} \end{align}

所以,我们得到了式(1)(1)的另一个形式

f(k,n)=i=1k(ni)\begin{align} f(k,n) &= \sum_{i=1}^k {n \choose i} \end{align}
import numpy as np
from scipy.special import comb

class Solution:
    def superEggDrop(self, egg: int, floor: int) -> int:
        if 2 ** egg > floor:
            return int(np.ceil(np.log2(floor + 1)))
        start = int(np.floor(np.log2(floor)))
        for n in range(start, floor + 1):
            if sum([comb(n, e + 1) for e in range(0,egg)]) >= floor:
                return int(n)

然而,这种解法并不会比上面的多项式解法快,因为计算一次值得复杂度就已经是O(n3)O(n^3),而我们难以估计精确值,需要在较大的区间内搜索。

在LeetCode中运行,耗时浮动在100ms~150ms。耗时比前一种解法多50%。


Copyright © 2020-2022 Dean Xu. All Rights reserved.