yhx-12243 的 NTT 到底写了些什么(详细揭秘)

这是 yhx-12243 的 NTT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
inline int & reduce(int &x) {return x += x >> 31 & mod;}
inline int & neg(int &x) {return x = (!x - 1) & (mod - x);}
u64 PowerMod(u64 a, int n, u64 c = 1) {for (; n; n >>= 1, a = a * a % mod) if (n & 1) c = c * a % mod; return c;}
namespace poly_base {
int l, n; u64 iv; vec w2;
void init(int n = N, bool dont_calc_factorials = true) {
int i, t;
for (inv[1] = 1, i = 2; i < n; ++i) inv[i] = u64(mod - mod / i) * inv[mod % i] % mod;
if (!dont_calc_factorials) for (*finv = *fact = i = 1; i < n; ++i) fact[i] = (u64)fact[i - 1] * i % mod, finv[i] = (u64)finv[i - 1] * inv[i] % mod;
t = min(n > 1 ? lg2(n - 1) : 0, 21),
*w2 = 1, w2[1 << t] = PowerMod(31, 1 << (21 - t));
for (i = t; i; --i) w2[1 << (i - 1)] = (u64)w2[1 << i] * w2[1 << i] % mod;
for (i = 1; i < n; ++i) w2[i] = (u64)w2[i & (i - 1)] * w2[i & -i] % mod;
}
inline void NTT_init(int len) {n = 1 << (l = len), iv = mod - (mod - 1) / n;}
void DIF(int *a) {
int i, *j, *k, len = n >> 1, R, *o;
for (i = 0; i < l; ++i, len >>= 1)
for (j = a, o = w2; j != a + n; j += len << 1, ++o)
for (k = j; k != j + len; ++k)
R = (u64)*o * k[len] % mod, reduce(k[len] = *k - R), reduce(*k += R - mod);
}
void DIT(int *a) {
int i, *j, *k, len = 1, R, *o;
for (i = 0; i < l; ++i, len <<= 1)
for (j = a, o = w2; j != a + n; j += len << 1, ++o)
for (k = j; k != j + len; ++k)
reduce(R = *k + k[len] - mod), k[len] = u64(*k - k[len] + mod) * *o % mod, *k = R;
}
inline void DNTT(int *a) {DIF(a);}
inline void IDNTT(int *a) {
DIT(a), std::reverse(a + 1, a + n);
for (int i = 0; i < n; ++i) a[i] = a[i] * iv % mod;
}
}

它为什么跑这么快?DIT 和 DIF 在干啥?预处理的原根为何和大多数人的不一样?这篇文章将为你解开这一奥秘(

先来看 init 函数 w2[1 << t] = PowerMod(31, 1 << (21 - t)); 为什么是 \(31\)

我们发现 \(31^{2^{23}}=1\) 同时它模 \(998244353\) 的阶是 \(2^{23}\) 的倍数,也就是说它在进行 NTT 时和 \(3^{119}\) 具有相似的性质,事实上,这里的确可以换为 \(3^{119}\)

平时我的写法都要预处理 \(21\) 种原根的次幂,为什么这里只用处理一种原根呢?我们将 \(31\) 改为 \(3^{119}\) 输出一下这段代码预处理的原根前 \(8\) 项,发现结果如下:

1
1 911660635 372528824 488723995 929031873 373294451 628914303 661054123

再来看平常写法预处理的原根:

1
2
3
4
1: 1
2: 1 911660635
4: 1 372528824 911660635 488723995
8: 1 929031873 372528824 628914303 911660635 373294451 488723995 661054123

我们发现对这一结果蝴蝶变换(二进制翻转)可以得到如下结果:

1
2
3
4
1: 1
2: 1 911660635
4: 1 911660635 372528824 488723995
8: 1 911660635 372528824 488723995 929031873 373294451 628914303 661054123

我们发现 \(1\)\(2\) 的前缀,\(2\)\(4\) 的前缀……

经过冷静思考,我们发现这是显然的,蝴蝶变换是 \(0\) 不动,偶数放左边,奇数放右边,分别进行少一位的蝴蝶变换,而根据 \(\omega_{2n}^{2i}=\omega_n^i\) 所以它前一半就是对 \(\frac{n}{2}\) 范围的原根做蝴蝶变换的结果。

代码在做什么也很好懂了,预处理出 \(g^{2^k}\) 放在 \(2^{21-k}\) 处(即蝴蝶变换后的结果),再递推得到其他结果(\(g^{2^j+2^k}=g^{2^j}\times g^{2^k}\),二进制翻转后也可以这样找每个为 \(1\) 的位乘上)。

这样预处理原根有什么用?等下就知道了。

我们还要知道它的基本原理:DIT/DIF。在 rushcheyo 学长《转置原理及其应用》中我们了解到 DIT(decimation in time,按时域抽取)-FFT 可以将蝴蝶变换后的系数向量转化为点值向量; DIF(decimation in frequency,按频域抽取)-FFT 可以将系数向量转化为蝴蝶变换后的点值向量,二者互为置换。

我们发现可以用 DIF 实现 DFT,用 DIT 实现 IDFT 于是我们就不用进行蝴蝶变换了。

这是我写的一份朴素的 DIT/DIF-NTT:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
void init_Poly() {
for (int l = 1; l < (1 << 21); l <<= 1) {
gw[l] = 1;
int gn = pow(g, (Mod - 1) / (l << 1), Mod);
for (int j = 1; j < l; ++j) {
gw[l | j] = 1ll * gw[l | (j - 1)] * gn % Mod;
}
}
}
void DIT(int *A, int lim, bool flag) {
for (int l = 1; l < lim; l <<= 1) {
int *k = A;
for (int i = 0; i < lim; i += (l << 1), k += (l << 1)) {
int *x = k;
for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) {
int o = 1ll * x[l] * *g % Mod;
x[l] = (*x + Mod - o) % Mod, *x = (*x + o) % Mod;
}
}
}
int iv = pow(lim, Mod - 2, Mod);
for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * iv % Mod;
std::reverse(A + 1, A + lim);
}
void DIF(int *A, int lim, bool flag) {
for (int l = lim / 2; l >= 1; l >>= 1) {
int *k = A;
for (int i = 0; i < lim; i += (l << 1), k += (l << 1)) {
int *x = k;
for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) {
int o = x[l];
x[l] = 1ll * (*x + Mod - o) * *g % Mod, *x = (*x + o) % Mod;
}
}
}
}

这里的原根是最朴素的处理方式,而在进行 DIT/DIF 的时候,我们需要移动 \(\operatorname{O}(n\log n)\) 次原根,而 yhx-12243 的 DIT/DIF 只需要移动 \(\operatorname{O}(n)\) 次。

我们还发现一件神奇的事:yhx-12243 的 DIT 除了最外层 \(len\) 的枚举顺序,似乎都在做 DIF,而 DIF 除了最外层 \(len\) 的枚举顺序,似乎都在做 DIT!

这是一张 DIT-FFT 和 DIF-FFT 的示意图:

DIT-FFT 和 DIF-FFT 的示意图

我们观察到 DIT-FFT 时如果对系数向量进行了蝴蝶变换,对 \((0,4)\) 操作变为了对 \((0,1)\) 操作,对 \((4,6)\) 操作变为了对 \((1,3)\) 操作,如果不对系数向量做蝴蝶变换并保持原先的操作呢(即仍然是对 \((0,4)\) 操作,对 \((4,6)\) 操作)?好像这样仍然会得到一个点值数组,这个点值数组正是蝴蝶变换后的点值数组!

原因是简单的:观察到蝴蝶变换的置换 \(A\) 有:\(A^{-1}=A\) 对于输入的系数数组做这一置换,运算过程不变,那么答案也应当也被做了该置换,于是 \(A\circ A=I\)(输入),\(I\circ A=A\)(答案)。

而原先要找的原根,也要对应的蝴蝶变换一下,这时候预处理蝴蝶变换后的原根的作用就体现出来了!

更为重要的是,对于一个 \(len\) 覆盖到的范围,所用的原根次幂是相同的(例如第一层变换中的 \((0,4),(1,5),(2,6),(3,7)\),第二层变换中的 \((0,2),(1,3)\)\((4,6),(5,7)\)

以上内容可以手画一下长为 \(16\) 的 DIT-FFT 来加深理解。

于是按从大到小枚举 \(len\) 的顺序做 DIT,干的就是 DIF 的事,同理我们也可以得到按从小到大枚举 \(len\) 的顺序做 DIF,干的就是 DIT 的事,而这种做法因为只需要移动 \(T(n+\frac{n}{2}+\frac{n}{4}+\cdots)=\operatorname{O}(n)\) 次原根所以会比原先快一些。

下面进行一些可能并不靠谱的效率差异比较(以下三份代码都使用 unsigned long long 优化,即用 ull 存储中间结果减少取模):

  1. 朴素 FFT 279.439 ms,代码 2.43 KB
  2. DIT-DIF FFT 212.99 ms,代码 2.93 KB
  3. 优化 DIT-DIF FFT 192.85 ms,代码 2.94 KB

可见 DIT-DIF FFT 相较于朴素 FFT 相比,有较大优化,而优化 DIT-DIF FFT 相较于 DIT-DIF FFT 有小幅度优化,且代码不长,实现难度不大,不失为一种较好的简单 NTT 实现方式。


yhx-12243 的 NTT 到底写了些什么(详细揭秘)
https://blog.seniorious.cc/2021/yhx-12243-NTT/
作者
Seniorious
发布于
2021年4月2日
许可协议