LeetCode第 330 场周赛T2:猴子碰撞的方法数 中遇到了快速幂的问题,于是写了这篇笔记。

本文章引用了 知乎 用户 Pecco 的文章:算法学习笔记(4):快速幂 的部分内容。

要求

计算anmodma ^ n \bmod m ,以下假定 m = 1e9 + 7

朴素的想法

1
2
3
4
5
6
7
8
9
typedef long long ll;
int MOD = 1e9 + 7;
ll pow1(ll a, ll n)
{
ll rst = 1;
while(n--)
rst = (rst * a) % MOD;
return rst;
}

时间复杂度O(n)O(n) ,空间复杂度O(1)O(1)

递归快速幂

根据 n 的奇偶性,可以将计算ana^n 转化为计算an1a^{n - 1}an2a^{\frac{n}{2}} 的子问题,子问题又可以转化为新的子问题:

an={aan1,if n is oddan2an2,if n is even but not 01,if n is 0a^n = \begin{cases} a \cdot a^{n - 1} , & \text{if }n\text{ is odd} \\ a^{\frac{n}{2}} \cdot a^{\frac{n}{2}}, & \text{if }n\text{ is even but not 0} \\ 1, & \text{if }n\text{ is 0} \end{cases}

可以得到一个递归算法:

1
2
3
4
5
6
7
8
9
10
11
12
typedef long long ll;
int MOD = 1e9 + 7;
ll pow2(ll a, ll n)
{
if(n == 0) // n is 0
return 1;
if(n & 1) // n is odd
return ((pow2(a, n - 1) * a) % MOD);
// n is even but not 0
ll t = pow2(a, n >> 1);
return (t * t) % MOD;
}

时间复杂度O(logn)O(\log n) ,空间复杂度O(logn)O(\log n)

非递归快速幂

假设 n = 1010 的二进制为 1010B, 则a10=a(1010)2=a(1000)2a(10)2=a23a21a^{10} = a^{(1010)_{2}} = a^{(1000)_2} \cdot a^{(10)_2} = a^{2^3} \cdot a ^ {2^1}

一般地,若:

n=(ntnt1n1n0)2=nt2t+nt12t1++n121+n020\begin{align} n & = (n_tn_{t-1} \cdots n_1n_0)_2 \nonumber\\ & = n_t2^t + n_{t-1}2^{t-1} + \cdots + n_12^1 + n_02^0 \nonumber \end{align}

则:

an=ant2t+nt12t1++n121+n020=ant2t×ant12t1××an121×an020\begin{alignat}{2} a^n & = a^{n_t2^t + n_{t-1}2^{t-1} + \cdots + n_12^1 + n_02^0} \nonumber\\ & = a^{n_t2^t} \times a^{n_{t-1}2^{t-1}} \times \cdots \times a^{n_12^1} \times a^{n_02^0} \nonumber \end{alignat}

我们遍历 n 的二进制(从低位到高位),如果 n 的二进制的第 t 位为 1,则让结果乘以a2ta^{2^t} ,而将a2ta ^ {2^t} 转化为a2t+1a^{2^{t + 1}} 只需要计算一次平方:a2t+1=(a2t)2a^{2^{t + 1}} = {(a ^ {2^t})} ^ 2 ,那么就可以使用循环将计算ana ^ {n} 的时间复杂度降低为O(logn)O(\log n)n 的二进制的位数为log2n+1\lfloor \log_{2}n\rfloor + 1):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
typedef long long ll;
int MOD = 1e9 + 7;
ll pow3(ll a, ll n)
{
ll rst = 1;
while (n)
{
// 第 i 次循环时,n 的二进制的最后一位为函数调用时的 n 的二进制的第 i 位,i 从 0 开始计
// 第 i 次循环时,a 已经成为函数调用时的 a 的 2^i 次方
if (n & 1) // 若函数调用时的 n 的二进制中第 i 位为 1
rst = (rst * a) % MOD; // rst *= a ^ (2 ^ i),注意代码中的 a 已经不是最开始调用函数时的 a 了
a = (a * a) % MOD; // a ^ (2 ^ (i + 1)) = (a ^ (2 ^ i)) ^ 2
n >>= 1; // 以便下一个循环中,取 n 的二进制的下一位
}
return rst;
}

时间复杂度O(logn)O(\log n) ,空间复杂度O(1)O(1)

对比

运行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
int main()
{
long long a = 2, n = 1000000000, m = 1e9 + 7, rst;
DWORD time_star, time_end;

cout << "*********** pow1 ************" << endl;
time_star = GetTickCount();
rst = pow1(a, n);
time_end = GetTickCount();
cout << "rst: " << rst << ", time: " << time_end - time_star << "ms" << endl;

cout << "*********** pow2 ************" << endl;
time_star = GetTickCount();
rst = pow2(a, n);
time_end = GetTickCount();
cout << "rst: " << rst << ", time: " << time_end - time_star << "ms" << endl;

cout << "*********** pow3 ************" << endl;
time_star = GetTickCount();
rst = pow3(a, n);
time_end = GetTickCount();
cout << "rst: " << rst << ", time: " << time_end - time_star << "ms" << endl;
return 0;
}

输出:

1
2
3
4
5
6
*********** pow1 ************
rst: 140625001, time: 5797ms
*********** pow2 ************
rst: 140625001, time: 0ms
*********** pow3 ************
rst: 140625001, time: 0ms

拓展

泛型非递归快速幂

在计算ana ^ {n} 时,若 a 的类型 a 支持 乘法,并且满足 乘法结合律 (因为快速幂改变了乘法的运算顺序)便可以使用快速幂,例如传说中的矩阵快速幂。

实现方式与之前的方法非常类似。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
typedef long long ll;
template <typename T>
T powT(T a, ll n)
{
T ans = 1; // 赋值为乘法单位元,可能要根据构造函数修改,例如数字的乘法单位元为 1,矩阵的乘法单位元为单位矩阵
while (n)
{
if (n & 1)
ans = ans * a;
a = a * a;
n >>= 1;
}
return ans;
}

时间复杂度O(clogn)O(c \log n)cc 为此类型计算一次乘法的时间复杂度,空间复杂度O(1)O(1) (结果不应看作额外空间)。

矩阵快速幂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
typedef long long ll;
int MOD = 1e9 + 7;
struct matrix // 2 * 2 矩阵,也可以使用二维数组进行模拟,但也要重载乘法
{
ll a1, a2, b1, b2;
matrix(ll a1, ll a2, ll b1, ll b2) : a1(a1), a2(a2), b1(b1), b2(b2) {}
matrix operator*(const matrix &y) // 重载乘法
{
matrix ans((a1 * y.a1 + a2 * y.b1) % MOD,
(a1 * y.a2 + a2 * y.b2) % MOD,
(b1 * y.a1 + b2 * y.b1) % MOD,
(b1 * y.a2 + b2 * y.b2) % MOD);
return ans;
}
};

matrix powM(matrix a, ll n)
{
matrix ans(1, 0, 0, 1); // 单位矩阵
while (n)
{
if (n & 1)
ans = ans * a;
a = a * a;
n >>= 1;
}
return ans;
}

应用-计算斐波那契数列

斐波那契数列:

Fn={1,(n2)Fn1+Fn2,(n>2)F_n = \begin{cases} 1, & (n \leqslant 2) \\ F_{n - 1} + F_{n - 2}, & (n > 2) \end{cases}

计算FnmodmF_n \bmod m ,以下假定 m = 1e9 + 7

常规非递归算法

1
2
3
4
5
6
7
8
9
10
11
12
13
typedef long long ll;
int MOD = 1e9 + 7;
ll Fib(ll n)
{
ll a = 0, b = 1, c = 1;
while(--n)
{
c = (a + b) % MOD;
a = b;
b = c;
}
return c;
}

时间复杂度O(n)O(n) ,空间复杂度O(1)O(1)

矩阵快速幂算法

A=(0111)A = \begin{pmatrix} 0 & 1 \\ 1 & 1 \end{pmatrix}

则有:

(FnFn+1)=A(Fn1Fn)=A2(Fn2Fn1)==An1(11)\begin{pmatrix} F_{n} \\ F_{n + 1} \end{pmatrix} = A \begin{pmatrix} F_{n-1} \\ F_{n} \end{pmatrix} = A^2 \begin{pmatrix} F_{n-2} \\ F_{n-1} \end{pmatrix} = \cdots = A ^ {n-1} \begin{pmatrix} 1 \\ 1 \end{pmatrix}

An1=(a1a2b1b2)A^{n - 1} = \begin{pmatrix} a1 & a2 \\ b1 & b2 \end{pmatrix}

即有:

Fn=a1+a2F_n = a1 + a2

1
2
3
4
5
6
7
8
typedef long long ll;
int MOD = 1e9 + 7;
ll Fib(ll n)
{
matrix A(0, 1, 1, 1);
A = powM(A, n - 1); // A^(n - 1)
return (A.a1 + A.a2) % MOD;
}

时间复杂度O(logn)O(\log n) ,空间复杂度O(1)O(1)

对比

运行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
int main()
{
long long a = 2, n = 1000000000, m = 1e9 + 7, rst;
DWORD time_star, time_end;

cout << "*********** Fib ************" << endl;
time_star = GetTickCount();
rst = Fib(n);
time_end = GetTickCount();
cout << "rst: " << rst << ", time: " << time_end - time_star << "ms" << endl;

cout << "*********** FibM ************" << endl;
time_star = GetTickCount();
rst = FibM(n);
time_end = GetTickCount();
cout << "rst: " << rst << ", time: " << time_end - time_star << "ms" << endl;

return 0;
}

输出:

1
2
3
4
*********** Fib ************
rst: 21, time: 5297ms
*********** FibM ************
rst: 21, time: 0ms