多项式学习笔记

超全的多项式全家桶(❁´◡`❁)

感谢Meteorshower-Y的大力支持和指导, 这是大佬的blog!!ヾ(≧▽≦*)o

多项式学习笔记

快速傅里叶变换(FFT)

快速傅里叶变换(FFT),主要用于加速多项式乘法:对于两个多项式 $A$ 和 $B$,FFT 可以将朴素算法的 $O(n^2)$ 优化为 $O(n \log n)$。

单位根

先看一下单位根的几个性质,在接下来的算法中有很大的用途。

  1. $\omega_n^k = e^{\frac{2\pi i k}{n}}$
  2. $\omega_{dn}^{dk} = \omega_n^k$
  3. 若 $\omega_n^k = a + bi$,则 $\omega_n^{-k} = a - bi$
  4. $\omega_n^{k + \frac{n}{2}} = -\omega_n^k$

以上性质均可由欧拉公式 $e^{i\theta} = \cos\theta + i\sin\theta$ 推得。

离散傅里叶变换(DFT)

离散傅里叶变换(DFT)主要是利用分治思想,其依据是:一个 $n$ 次的多项式可以由 $n + 1$ 个点值唯一确定。

首先将多项式

$$A(x) = \sum_{i=0}^{n} a_i x^i$$

的系数按下标的奇偶性分类,得到

$$A_0(x) = a_0 + a_2 x + a_4 x^2 + \cdots \\ A_1(x) = a_1 + a_3 x + a_5 x^2 + \cdots$$

所以我们可以表示为:

$$A(x) = A_0(x^2) + x \cdot A_1(x^2)$$

将 $\omega_n^k$ 和 $\omega_n^{k + \frac{n}{2}}$ 代入得:

$$\left\{ \begin{aligned} &A(\omega_n^k) = A_0(\omega_n^{2k}) + \omega_n^k A_1(\omega_n^{2k}) \\ &A(\omega_n^{k + \frac{n}{2}}) = A_0(\omega_n^{2k}) - \omega_n^k A_1(\omega_n^{2k}) \end{aligned} \right.$$

同时我们可以发现两个式子只有第二项的符号不同,而 $A_0$、$A_1$ 的规模都减半,递归计算即可。

时间复杂度 $O(n \log n)$。

在这里我们将系数变成了点值。

离散傅里叶逆变换(IDFT)

离散傅里叶逆变换(IDFT),可以将点值快速转化为系数,从而得出结果多项式。

需要用到单位根反演:

$$\frac{1}{n} \sum_{i=0}^{n-1} \omega_n^{x i} = [x \bmod n = 0]$$

证明:

由于 $\omega_n^{x i} = \omega_n^{x (i-1)} \cdot \omega_n^x$,

所以 $\omega_n^{x i}$(关于 $i$)为等比数列,

$$\therefore \frac{1}{n} \sum_{i=0}^{n-1} \omega_n^{x i} = \left\{ \begin{aligned} &\frac{1}{n} \sum_{i=0}^{n-1} 1^i = \frac{n}{n} = 1 & x \bmod n = 0 \\ &\frac{1}{n} \cdot \frac{1 - \omega_n^{n x}}{1 - \omega_n^x} = \frac{1}{n} \cdot \frac{1 - 1^x}{1 - \omega_n^x} = 0 & x \bmod n \ne 0 \end{aligned} \right.$$

应用到卷积上:

$$设\ c = a \ast b \\ \begin{aligned} c_i &= \sum_{j=0}^{i} a_j \cdot b_{i-j} \\ &= \sum_{p=0}^{n-1}\sum_{q=0}^{n-1} a_p \cdot b_q\, [(p + q - i) \bmod n = 0] \\ n c_i &= \sum_{p=0}^{n-1}\sum_{q=0}^{n-1} a_p \cdot b_q \sum_{j=0}^{n-1} \omega_n^{(p+q-i)j} \\ &= \sum_{j=0}^{n-1} \omega_n^{-ij} \bigg( \sum_{p=0}^{n-1} \omega_n^{pj} a_p \bigg) \bigg( \sum_{q=0}^{n-1} \omega_n^{qj} b_q \bigg) \end{aligned} \\ 设\ f_a(j) = \sum_{i=0}^{n-1} \omega_n^{ij} a_i,\quad f_a^{-1}(j) = \sum_{i=0}^{n-1} \omega_n^{-ij} a_i \\ \begin{aligned} n c_i &= \sum_{j=0}^{n-1} \omega_n^{-ij} f_a(j) f_b(j) \\ &= \sum_{j=0}^{n-1} \omega_n^{-ij} f_c(j) \\ &= f_{f_c}^{-1}(i) \end{aligned}$$

因为 $f_a$ 就是 $a$ 在 DFT 后的结果,所以 $f_a^{-1}$ 就是对应的 IDFT,最后除以对应长度 $n$,即为所求。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

const int N = 4e6 + 10;
const double pi = acos(-1.0);

int n, m;

// Minimal complex number (real part a, imaginary part b) offering exactly the
// three arithmetic operators the FFT butterfly needs.
struct Complex
{
    double a, b;

    Complex(double x = 0, double y = 0)
    {
        a = x;
        b = y;
    }

    friend Complex operator + (Complex lhs, Complex rhs)
    {
        Complex sum(lhs.a + rhs.a, lhs.b + rhs.b);
        return sum;
    }

    friend Complex operator - (Complex lhs, Complex rhs)
    {
        Complex difference(lhs.a - rhs.a, lhs.b - rhs.b);
        return difference;
    }

    friend Complex operator * (Complex lhs, Complex rhs)
    {
        // (a1 + b1 i)(a2 + b2 i) = (a1 a2 - b1 b2) + (b1 a2 + b2 a1) i
        double re = lhs.a * rhs.a - lhs.b * rhs.b;
        double im = lhs.b * rhs.a + rhs.b * lhs.a;
        return Complex(re, im);
    }
};

int recover[N];

Complex F[N], G[N], H[N];

// In-place iterative radix-2 FFT over a[0..len).
//   type ==  1 : forward transform
//   type == -1 : inverse transform (conjugate roots, then scaling)
// Reads the global bit-reversal table `recover`, which the caller must have
// filled for this `len` (assumed to be a power of two — the butterfly loops
// only make sense for such lengths). Note that after an inverse transform only
// the real parts are divided by len; imaginary parts are left unscaled.
void FFT(Complex *a, int len, int type)
{
    // Bit-reversal permutation; the pos < rev test swaps each pair only once.
    for(int pos = 0; pos < len; pos++)
    {
        int rev = recover[pos];
        if(pos < rev) swap(a[pos], a[rev]);
    }

    for(int half = 1; half < len; half <<= 1)
    {
        // Primitive (2 * half)-th root of unity, conjugated when type == -1.
        Complex root(cos(pi / half), type * sin(pi / half));
        int span = half << 1;
        for(int base = 0; base < len; base += span)
        {
            Complex w(1, 0);
            for(int off = 0; off < half; off++)
            {
                int lo = base + off, hi = lo + half;
                Complex even = a[lo];
                Complex odd = w * a[hi];
                a[lo] = even + odd;
                a[hi] = even - odd;
                w = w * root;
            }
        }
    }

    if(type != -1) return;
    for(int pos = 0; pos < len; pos++)
        a[pos].a /= len;
}

// Reads two polynomials of degree n and m, multiplies them with FFT and prints
// the n + m + 1 product coefficients rounded to the nearest integer.
int main()
{
    scanf("%d%d", &n, &m);
    for(int i = 0; i <= n; i++)
        scanf("%lf", &F[i].a);
    for(int i = 0; i <= m; i++)
        scanf("%lf", &G[i].a);

    // Smallest power of two strictly greater than n + m; cnt = log2(len).
    int len = 1, cnt = 0;
    while(len <= (n + m)) len <<= 1, cnt++;

    // Bit-reversal table for indices [0, len). recover[0] is already 0 (global,
    // zero-initialised). Starting at 1 means the loop body only runs when
    // len >= 2, i.e. cnt >= 1, so the shift amount cnt - 1 is never negative
    // (a negative shift count is undefined behaviour).
    for(int i = 1; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    FFT(F, len, 1), FFT(G, len, 1);
    // Only [0, len) holds transformed values, so the pointwise product stops
    // at len - 1.
    for(int i = 0; i < len; i++)
        H[i] = F[i] * G[i];
    FFT(H, len, -1);

    // floor(x + 0.5) rounds to nearest for negative coefficients as well;
    // a plain (int)(x + 0.5) truncates toward zero and is off by one there.
    for(int i = 0; i <= n + m; i++)
        printf("%d ", static_cast<int>(floor(H[i].a + 0.5)));
    return 0;
}

快速数论变换(NTT)

快速数论变换(NTT)相比于 FFT,虽然时间复杂度均为 $O(n \log n)$,但是 FFT 的精度难以保证,并且常数很大,所以有时 NTT 才是更好的选择。

原根

原根定义为:设 $m$ 为正整数,$a$ 是整数,若 $a$ 模 $m$ 的阶等于 $\varphi(m)$,则称 $a$ 为模 $m$ 的一个原根。

原根有一个很重要的性质可以支持像 FFT 中单位根一样的运算,即:若 $P$ 为素数,假设 $g$ 是 $P$ 的原根,那么 $g^i \bmod P$($0 \le i < P - 1$)的结果两两不同。

可以得到:

$$\omega_n \equiv g^{\frac{p - 1}{n}} \pmod p$$

然后我们就可以将 FFT 中的 $\omega_n$ 替换为 $g^{\frac{p - 1}{n}}$。

但是要注意 NTT 对模数有要求,其模数必须要满足原根的定义,否则是不能使用 NTT 的。比如 $998244353$ 就是 NTT 模数,其原根为 $3$。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

#define int long long

const int N = 4e6 + 10;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

int n, m;

int F[N], G[N], H[N];

// Returns a^b modulo the file-level constant `mod`, by binary exponentiation.
int qpow(int a, int b)
{
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

int recover[N];

// In-place iterative number-theoretic transform of a[0..len) modulo `mod`.
// type == 1 uses powers of g (forward transform); any other value uses gi, and
// type == -1 additionally multiplies every entry by len^{-1}.
// Reads the global bit-reversal table `recover`, which the caller must have
// prepared for this `len`.
void NTT(int *a, int len, int type)
{
    // Bit-reversal permutation; each pair is swapped once.
    for(int pos = 0; pos < len; pos++)
    {
        int rev = recover[pos];
        if(pos < rev) swap(a[pos], a[rev]);
    }

    for(int half = 1; half < len; half <<= 1)
    {
        int span = half << 1;
        // Root of unity of order `span` in the multiplicative group mod `mod`.
        int root = qpow(type == 1 ? g : gi, (mod - 1) / span);
        for(int base = 0; base < len; base += span)
        {
            int w = 1;
            for(int off = 0; off < half; off++)
            {
                int lo = base + off, hi = lo + half;
                int even = a[lo];
                int odd = w * a[hi] % mod;
                a[lo] = (even + odd) % mod;
                a[hi] = (even - odd + mod) % mod;
                w = (w * root) % mod;
            }
        }
    }

    if(type == -1)
    {
        int scale = qpow(len, mod - 2);
        for(int pos = 0; pos < len; pos++)
            a[pos] = a[pos] * scale % mod;
    }
}

// Reads two polynomials of degree n and m and prints the n + m + 1
// coefficients of their product modulo `mod`, computed with NTT.
signed main()
{
    scanf("%lld%lld", &n, &m);
    for(int i = 0; i <= n; i++)
        scanf("%lld", &F[i]);
    for(int i = 0; i <= m; i++)
        scanf("%lld", &G[i]);

    // Smallest power of two strictly greater than n + m; cnt = log2(len).
    int len = 1, cnt = 0;
    while(len <= (n + m)) len <<= 1, cnt++;

    // recover[0] is already 0 (zero-initialised global). Starting at 1 means
    // the body only runs when len >= 2, so cnt - 1 >= 0 and the shift is
    // well defined even when both inputs are constants (n + m == 0).
    for(int i = 1; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    NTT(F, len, 1), NTT(G, len, 1);
    for(int i = 0; i < len; i++)
        H[i] = (F[i] * G[i]) % mod;
    NTT(H, len, -1);

    for(int i = 0; i <= n + m; i++)
        printf("%lld ", H[i]);
    return 0;
}

快速沃尔什变换 (FWT)

给定两个长度为 $2^n$ 的序列 $A, B$,求序列 $C$,满足

$$C_i = \sum_{j \oplus k = i} A_j \times B_k$$

其中 $\oplus$ 表示位运算与、或、异或之一。

或运算

首先构造序列 $FWT[A]_i = \sum_{i = i | j} A_j$,即对 $i$ 的所有子集 $j$ 求和,显然会有

$$C_i = \sum_{i = j | k} A_j \times B_k \Rightarrow FWT[C] = FWT[A] \times FWT[B]$$

接下来就是考虑如何进行 $FWT$ 运算。记 $A_0$、$A_1$ 分别为 $A$ 的前一半和后一半,有

$$FWT[A] = merge(FWT[A_0], FWT[A_0] + FWT[A_1]) \\ IFWT[A] = merge(IFWT[A_0], IFWT[A_1] - IFWT[A_0])$$

参考代码如下:
// In-place Walsh transform for OR-convolution on a[0..len), modulo `mod`.
// Within every block of size 2 * half, the upper half accumulates
// type * (lower half): type = 1 performs the forward transform and type = -1
// subtracts the same contribution again, undoing it.
void FWT_or(int *a, int len, int type)
{
    for(int half = 1; half < len; half <<= 1)
        for(int base = 0; base < len; base += (half << 1))
            for(int off = 0; off < half; off++)
            {
                int lo = base + off, hi = lo + half;
                int updated = (a[hi] + a[lo] * type + mod) % mod;
                a[hi] = (updated + mod) % mod;
            }
}

与运算

与或运算同理,有

$$FWT[A] = merge(FWT[A_0] + FWT[A_1], FWT[A_1]) \\ IFWT[A] = merge(IFWT[A_0] - IFWT[A_1], IFWT[A_1])$$

参考代码如下:
// In-place Walsh transform for AND-convolution on a[0..len), modulo `mod`.
// Within every block of size 2 * half, the lower half accumulates
// type * (upper half): type = 1 performs the forward transform and type = -1
// subtracts the same contribution again, undoing it.
void FWT_and(int *a, int len, int type)
{
    for(int half = 1; half < len; half <<= 1)
        for(int base = 0; base < len; base += (half << 1))
            for(int off = 0; off < half; off++)
            {
                int lo = base + off, hi = lo + half;
                int updated = (a[lo] + a[hi] * type + mod) % mod;
                a[lo] = (updated + mod) % mod;
            }
}

异或运算

推导得

$$FWT[A] = merge(FWT[A_0] + FWT[A_1], FWT[A_0] - FWT[A_1]) \\ IFWT[A] = merge\left(\frac{IFWT[A_0] + IFWT[A_1]}{2}, \frac{IFWT[A_0] - IFWT[A_1]}{2}\right)$$

参考代码如下:
// In-place Walsh-Hadamard transform for XOR-convolution on a[0..len), modulo
// `mod`. Each butterfly maps (x, y) to ((x + y) * type, (x - y) * type).
// NOTE(review): `type` is used as a multiplier here, not as a +/-1 flag —
// presumably callers pass 1 for the forward transform and the modular inverse
// of 2 for the inverse transform; confirm against the call sites.
void FWT_xor(int *a, int len, int type)
{
    for(int half = 1; half < len; half <<= 1)
        for(int base = 0; base < len; base += (half << 1))
            for(int off = 0; off < half; off++)
            {
                int lo = base + off, hi = lo + half;
                int x = a[lo], y = a[hi];
                int sum = (x + y) % mod;
                int difference = (x - y + mod) % mod;
                a[lo] = (sum * type + mod) % mod;
                a[hi] = (difference * type + mod) % mod;
            }
}

任意模数快速数论变换

普通的 NTT 对模数是有要求的:其必须满足原根的相关定义,并且模数必须可以写成 $a \cdot 2^k + 1$ 的形式。

比如:

$$469762049 = 7 \times 2^{26} + 1\ (g = 3) \\ 998244353 = 119 \times 2^{23} + 1\ (g = 3) \\ 1004535809 = 479 \times 2^{21} + 1\ (g = 3)$$

如果题目中模数为 $10^9 + 7$,那么 NTT 就会受到限制,这时就可以使用任意模数 NTT(也可以称为三模数 NTT)。

计算时可以先找三个大质数,分别计算结果,然后用中国剩余定理(CRT)合并即可。

首先记三次 NTT 的结果为:

$$ans \equiv a_1 \pmod {p_1} \\ ans \equiv a_2 \pmod {p_2} \\ ans \equiv a_3 \pmod {p_3}$$

先合并前两个得到:

$$ans \equiv a_4 \pmod {p_1 p_2}$$

将其转化为等式为:

$$ans = k p_1 p_2 + a_4$$

接着求 $k$:

$$k \equiv (a_3 - a_4)\, p_1^{-1} p_2^{-1} \pmod {p_3}$$

所以:

$$ans \equiv k p_1 p_2 + a_4 \pmod {p_1 p_2 p_3}$$

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

#define int __int128

const int N = 4e5 + 10;
const int g = 3;

// Reads one optionally signed decimal integer from stdin, skipping any
// non-digit characters before it (a '-' among them makes the result negative).
// Returns 0 if the input ends before a digit is found.
int read()
{
    int value = 0, sign = 1;
    // Keep getchar()'s full return value: squeezing it into a char makes EOF
    // indistinguishable from a non-digit, and the skip loop would never end.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-') sign = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// Prints x in decimal followed by a single space. Negative values are printed
// with a leading '-' (the previous digit loop emitted garbage for them).
void write(int x)
{
    char digits[64];   // ample room: a 128-bit value has at most 39 digits
    int count = 0;
    const bool negative = x < 0;
    do
    {
        // Take the digit's absolute value instead of negating x, so the most
        // negative value cannot overflow.
        int digit = x % 10;
        if(digit < 0) digit = -digit;
        digits[count++] = static_cast<char>('0' + digit);
        x /= 10;
    } while(x != 0);
    if(negative) putchar('-');
    while(count > 0) putchar(digits[--count]);
    printf(" ");
}

int p[3] = {469762049, 998244353, 1004535809};

// Returns a^b modulo p[i] (the i-th NTT prime), by binary exponentiation.
int qpow(int a, int b, int i)
{
    const int modulus = p[i];
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % modulus;
        a = a * a % modulus;
    }
    return result % modulus;
}

// Returns x^(p[i] - 2) modulo p[i], i.e. the modular inverse of x modulo the
// i-th prime by Fermat's little theorem.
int inv(int x, int i)
{
    const int exponent = p[i] - 2;
    return qpow(x, exponent, i);
}

int gi[3];

void init()
{
for(int i = 0; i < 3; i++)
gi[i] = inv(g, i);
}

int F[N], G[N], H[N];

int recover[N];

void NTT(int *a, int len, int type, int f)
{
for(int i = 0; i < len; i++)
if(i < recover[i])swap(a[i], a[recover[i]]);
for(int k = 1; k < len; k <<= 1)
{
int x = qpow(type == 1 ? g : gi[f], (p[f] - 1) / (k << 1), f);
for(int i = 0; i < len; i += (k << 1))
{
int w = 1;
for(int j = 0; j < k; j++)
{
int y = a[i + j] % p[f];
int z = w * a[i + j + k] % p[f];
a[i + j] = (y + z) % p[f];
a[i + j + k] = (y - z + p[f]) % p[f];
w = w * x % p[f];
}
}
}
if(type == -1)
{
int iv = inv(len, f);
for(int i = 0; i < len; i++)
a[i] = a[i] * iv % p[f];
}
}

int A[N], B[N], C[3][N];

// Combines the residues C[0][i] (mod p[0]) and C[1][i] (mod p[1]) into
// H[i] modulo p[0] * p[1] with the Chinese remainder theorem, for every index
// i in [0, len] (inclusive).
void CRT(int len)
{
    const int M = p[0] * p[1];
    // The two inverses do not depend on i; computing them once avoids a
    // modular exponentiation per coefficient.
    const int inv_p1_mod_p0 = inv(p[1], 0);
    const int inv_p0_mod_p1 = inv(p[0], 1);
    for(int i = 0; i <= len; i++)
    {
        int term0 = p[1] * C[0][i] % M * inv_p1_mod_p0 % M;
        int term1 = p[0] * C[1][i] % M * inv_p0_mod_p1 % M;
        H[i] = (term0 + term1) % M;
    }
}

int n, m, mod;

// Folds the third residue C[2][i] (mod p[2]) into H[i], which holds the
// answer modulo p[0] * p[1], and reduces the combined value modulo the target
// modulus `mod`. Processes every index i in [0, len] (inclusive).
void merge(int len)
{
    const int M = p[0] * p[1];
    // (p[0] * p[1])^{-1} mod p[2] is loop-invariant; compute it once instead of
    // running a modular exponentiation for every coefficient.
    const int inv_M_mod_p2 = inv(M, 2);
    for(int i = 0; i <= len; i++)
    {
        // answer = k * M + H[i], with k determined modulo p[2].
        int k = ((C[2][i] - H[i]) % p[2] + p[2]) % p[2] * inv_M_mod_p2 % p[2];
        H[i] = ((k * M % mod + H[i] % mod) % mod + mod) % mod;
    }
}

void prework()
{
memcpy(A, F, sizeof(F));
memcpy(B, G, sizeof(G));
}

// Stores the pointwise product of A and B modulo p[x] in C[x][0..len).
void update(int x, int len)
{
    const int modulus = p[x];
    int *out = C[x];
    for(int idx = 0; idx < len; idx++)
        out[idx] = A[idx] * B[idx] % modulus;
}

// Reads n, m, the target modulus and two polynomials, multiplies them modulo
// each of the three NTT primes, then recombines the residues with CRT and
// prints the n + m + 1 product coefficients modulo `mod`.
signed main()
{
    init();
    n = read(), m = read(), mod = read();
    for(int i = 0; i <= n; i++)
        F[i] = read();
    for(int i = 0; i <= m; i++)
        G[i] = read();

    // Smallest power of two strictly greater than n + m; cnt = log2(len).
    int len = 1, cnt = 0;
    while(len <= (n + m)) len <<= 1, cnt++;

    // recover[0] is already 0 (zero-initialised global). Starting at 1 means
    // the body only runs when len >= 2, so the shift amount cnt - 1 is never
    // negative, even when n + m == 0.
    for(int i = 1; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    for(int i = 0; i < 3; i++)
    {
        prework();                 // fresh copies of F and G for this prime
        NTT(A, len, 1, i), NTT(B, len, 1, i);
        update(i, len);
        NTT(C[i], len, -1, i);
    }

    CRT(n + m); merge(n + m);
    for(int i = 0; i <= (n + m); i++)
        write(H[i]);
    return 0;
}

多项式乘法逆

定义多项式 $F^{-1}$ 为多项式 $F$ 的乘法逆元,满足

$$F \ast F^{-1} \equiv 1 \pmod{x^n}$$

假设我们已经得知 $F \ast G' \equiv 1 \pmod {x^{\frac{n}{2}}}$,来求 $F \ast G \equiv 1 \pmod {x^n}$。

$$\because F \ast G' \equiv 1 \pmod {x^{\frac{n}{2}}},\ F \ast G \equiv 1 \pmod {x^n} \\ \therefore F \ast G \equiv 1 \pmod {x^{\frac{n}{2}}} \\ \therefore G' - G \equiv 0 \pmod {x^{\frac{n}{2}}} \\ \therefore (G' - G)^2 \equiv 0 \pmod {x^n} \\ G'^2 - 2 G G' + G^2 \equiv 0 \pmod {x^n}$$

接下来两边同时乘以 $F$(并利用 $F \ast G \equiv 1$):

$$F G'^2 - 2 G' + G \equiv 0 \pmod {x^n} \\ \therefore G \equiv 2 G' - F G'^2 \pmod {x^n}$$

然后直接递归即可,使用 NTT,时间复杂度 $O(n \log n)$。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

#define int long long

const int N = 4e5 + 10;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

// Returns a^b modulo the file-level constant `mod`, by binary exponentiation.
int qpow(int a, int b)
{
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

// Returns x^(mod - 2) modulo `mod`, i.e. the modular inverse of x by Fermat's
// little theorem.
int inv(int x)
{
    const int exponent = mod - 2;
    return qpow(x, exponent);
}

int F[N], G[N];

int recover[N];

void NTT(int *a, int len, int type)
{
for(int i = 0; i < len; i++)
if(i < recover[i])swap(a[i], a[recover[i]]);
for(int k = 1; k < len; k <<= 1)
{
int x = qpow(type == 1 ? g : gi, (mod - 1) / (k << 1));
for(int i = 0; i < len; i += (k << 1))
{
int w = 1;
for(int j = 0; j < k; j++)
{
int y = a[i + j] % mod;
int z = w * a[i + j + k] % mod;
a[i + j] = (y + z) % mod;
a[i + j + k] = (y - z + mod) % mod;
w = w * x % mod;
}
}
}
if(type == -1)
{
int iv = inv(len);
for(int i = 0; i < len; i++)
a[i] = a[i] * iv % mod;
}
}

int c[N];

// Despite its name, this computes the multiplicative inverse of a polynomial:
// b = a^{-1} mod x^n. It solves the problem modulo x^{ceil(n/2)} recursively
// and then applies one Newton step b <- b * (2 - a * b).
// Uses the global scratch array `c` and overwrites the global table `recover`.
void mul(int n, int *a, int *b)
{
    if(n == 1)
    {
        b[0] = inv(a[0]);
        return;
    }

    int halfSize = (n + 1) >> 1;
    mul(halfSize, a, b);

    // Transform length: smallest power of two strictly greater than 2n.
    int len = 1, bits = 0;
    for(; len <= (n << 1); len <<= 1) bits++;
    for(int i = 0; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (bits - 1));

    // c = a mod x^n, zero-padded up to len.
    for(int i = 0; i < len; i++)
        c[i] = (i < n) ? a[i] : 0;

    NTT(c, len, 1);
    NTT(b, len, 1);
    for(int i = 0; i < len; i++)
    {
        int prod = b[i] * c[i] % mod;
        b[i] = (2 - prod + mod) % mod * b[i] % mod;
    }
    NTT(b, len, -1);

    // Truncate to x^n so the next level starts from a clean tail.
    for(int i = n; i < len; i++) b[i] = 0;
}

// Reads n and the n coefficients of F, then prints the first n coefficients of
// its multiplicative inverse modulo x^n.
signed main()
{
    int n;
    scanf("%lld", &n);

    int idx = 0;
    while(idx < n)
    {
        scanf("%lld", &F[idx]);
        idx++;
    }

    mul(n, F, G);

    for(idx = 0; idx < n; idx++)
    {
        int coeff = G[idx] % mod;
        printf("%lld ", (coeff + mod) % mod);
    }
    return 0;
}

多项式对数函数(多项式求ln)

定义多项式对数函数为

$$G \equiv \ln (F) \pmod {x^n}$$

假设我们有多项式 $F(x)$ 和 $G(x)$,记 $G \equiv \ln F \pmod {x^n}$(下面的 $'$ 表示求导):

$$G \equiv \ln F \pmod {x^n} \\ G' \equiv (\ln F)' \pmod {x^n} \\ G' \equiv \ln'(F) \ast F' \pmod {x^n} \\ G' \equiv \frac{F'}{F} \pmod {x^n}$$

多项式求逆,再积回去就好啦。

需要用到求导:$(x^a)' = a x^{a - 1}$,积分:$\int x^a \mathrm{d}x = \frac{1}{a + 1} x^{a + 1}$。需要保证 $F_0 = 1$。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

#define int long long

const int N = 4e5 + 10;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

int recover[N];

// Returns a^b modulo the file-level constant `mod`, by binary exponentiation.
int qpow(int a, int b)
{
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

// Returns x^(mod - 2) modulo `mod`, i.e. the modular inverse of x by Fermat's
// little theorem.
int inv(int x)
{
    const int exponent = mod - 2;
    return qpow(x, exponent);
}

void NTT(int *a, int len, int type)
{
for(int i = 0; i < len; i++)
if(i < recover[i])swap(a[i], a[recover[i]]);
for(int k = 1; k < len; k <<= 1)
{
int x = qpow(type == 1 ? g : gi, (mod - 1) / (k << 1));
for(int i = 0; i < len; i += (k << 1))
{
int w = 1;
for(int j = 0; j < k; j++)
{
int y = a[i + j] % mod;
int z = w * a[i + j + k] % mod;
a[i + j] = (y + z) % mod;
a[i + j + k] = (y - z + mod) % mod;
w = w * x % mod;
}
}
}
if(type == -1)
{
int iv = inv(len);
for(int i = 0; i < len; i++)
a[i] = a[i] * iv % mod;
}
}

// Computes b = a^{-1} mod x^n by Newton iteration: the problem is solved
// modulo x^{ceil(n/2)} recursively, then one step b <- b * (2 - a * b) doubles
// the precision. a[0] must be invertible modulo `mod`. b is transformed over
// [0, len), so entries beyond the already computed prefix are expected to be
// zero (callers pass zero-filled arrays). Overwrites the global `recover`.
void inverse(int *a, int *b, int n)
{
    if(n == 1)
    {
        b[0] = inv(a[0]);
        return;
    }
    inverse(a, b, (n + 1) >> 1);

    // Transform length: smallest power of two strictly greater than 2n.
    int len = 1, cnt = 0;
    while(len <= (n << 1)) len <<= 1, cnt++;
    for(int i = 0; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    // Scratch copy of a mod x^n. It lives in static storage instead of on the
    // stack: an N-element automatic array in every recursive frame costs
    // megabytes per level and can overflow the stack. One shared buffer is
    // safe because it is only touched after the recursive call has returned.
    // The two loops below initialise all of [0, len), so no memset is needed.
    static int c[N];
    for(int i = 0; i < n; i++)
        c[i] = a[i];
    for(int i = n; i < len; i++)
        c[i] = 0;

    NTT(c, len, 1), NTT(b, len, 1);
    for(int i = 0; i < len; i++)
        b[i] = (2 - b[i] * c[i] % mod + mod) % mod * b[i] % mod;
    NTT(b, len, -1);

    // Truncate to x^n.
    for(int i = n; i < len; i++) b[i] = 0;
}

// Writes the product of polynomials a and b into c[0..len), where len is the
// smallest power of two strictly greater than n + m.
// Side effects: a and b are replaced by their forward transforms, and the
// global table `recover` is rebuilt for this len.
void mul(int *a, int *b, int *c, int n, int m)
{
    int size = 1, bits = 0;
    for(; size <= (n + m); size <<= 1) bits++;

    for(int idx = 0; idx < size; idx++)
        recover[idx] = (recover[idx >> 1] >> 1) | ((idx & 1) << (bits - 1));

    NTT(a, size, 1);
    NTT(b, size, 1);
    for(int idx = 0; idx < size; idx++)
        c[idx] = a[idx] * b[idx] % mod;
    NTT(c, size, -1);
}

// c[i] = (a[i] + k * b[i] + mod) % mod for every i in [0, max(n, m)]
// (inclusive upper bound). With k = -1 this subtracts b from a.
void add(int *a, int *b, int *c, int n, int m, int k)
{
    const int last = max(n, m);
    for(int idx = 0; idx <= last; idx++)
    {
        int scaled = k * b[idx];
        c[idx] = (a[idx] + scaled + mod) % mod;
    }
}

// Formal derivative: b[i - 1] = i * a[i] (mod `mod`) for 1 <= i < n, and the
// top slot b[n - 1] is cleared.
void diff(int *a, int *b, int n)
{
    int deg = 1;
    while(deg < n)
    {
        b[deg - 1] = deg * a[deg] % mod;
        deg++;
    }
    b[n - 1] = 0;
}

// Formal integral with zero constant term: b[i] = a[i - 1] / i (mod `mod`)
// for 1 <= i < n, and b[0] = 0.
void integ(int *a, int *b, int n)
{
    int deg = 1;
    while(deg < n)
    {
        int coeff = a[deg - 1];
        b[deg] = coeff * inv(deg) % mod;
        deg++;
    }
    b[0] = 0;
}

int F[N], G[N], H[N];

// Computes b = ln(a) mod x^n as the integral of a' * a^{-1}.
// Uses the global array H as scratch for the product.
void polyln(int *a, int *b, int n)
{
    // Heap-backed, zero-initialised scratch buffers. Two N-element automatic
    // arrays would put several megabytes on the stack on top of inverse()'s
    // recursion, which risks a stack overflow.
    vector<int> f(N), h(N);
    diff(a, f.data(), n);               // f = a'
    inverse(a, h.data(), n);            // h = a^{-1} mod x^n
    mul(f.data(), h.data(), H, n, n);   // H = a' * a^{-1}
    integ(H, b, n);                     // b = integral of H
}

int n;

// Reads n and the n coefficients of F, then prints the first n coefficients of
// ln(F) modulo x^n.
signed main()
{
    scanf("%lld", &n);

    int idx = 0;
    while(idx < n)
    {
        scanf("%lld", &F[idx]);
        idx++;
    }

    polyln(F, G, n);

    for(idx = 0; idx < n; idx++)
        printf("%lld ", G[idx]);
    return 0;
}

多项式指数函数(多项式exp)

定义多项式指数函数为

$$G(x) \equiv e^{F(x)} \pmod {x^n}$$

牛顿迭代

牛顿迭代用于求函数零点,通过不断地切线逼近所求值,但最终也只是近似值,迭代的次数越多,精确度越高,误差越小。

假如我们要对一个非常大的数 $a$ 开方(手算),可以利用牛顿法来解决这个问题,其本质上是求 $f(x) = x^2 - a$ 精确到整数的零点。假设我们已经求得了一个近似值 $x_0$,那么我们只需要过 $(x_0, f(x_0))$ 这个点作这个函数图像的切线,取切线与 $x$ 轴的交点作为新的 $x_0$。

假设我们要求一个函数 $f(x)$ 的零点,初始近似值是 $x_0$,则切线方程为

$$y = f'(x_0)(x - x_0) + f(x_0)$$

令 $y = 0$,得到 $x = x_0 - \frac{f(x_0)}{f'(x_0)}$。

假设我们现在要求 $F(G(x)) \equiv 0$,然后利用上面的式子每一次令

$$G(x) = G_0(x) - \frac{F(G_0(x))}{F'(G_0(x))}$$

然后就可以很快地逼近真实值(在多项式意义下,每迭代一次,正确的项数翻倍)。

接下来推一下多项式 exp:

$$B(x) \equiv e^{A(x)} \pmod {x^n} \\ \ln B(x) - A(x) \equiv 0 \pmod {x^n}$$

现在问题变为了求 $G(x)$,使得 $F(G(x)) = \ln G(x) - A(x) \equiv 0$。

然后对 $G$ 求导($A(x)$ 视为常数),

$$F'(G_0(x)) = \frac{1}{G_0(x)}$$

然后接着代入上面牛顿迭代的式子,

$$G(x) = G_0(x)\,(1 - \ln G_0(x) + A(x))$$

每次迭代,使用多项式求 $\ln$,然后再做一遍多项式乘法,就可以得到答案,时间复杂度 $O(n \log n)$。

需要保证 $F_0 = 0$。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

#define int long long

const int N = 4e5 + 10;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

int recover[N];

// Returns a^b modulo the file-level constant `mod`, by binary exponentiation.
int qpow(int a, int b)
{
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

// Returns x^(mod - 2) modulo `mod`, i.e. the modular inverse of x by Fermat's
// little theorem.
int inv(int x)
{
    const int exponent = mod - 2;
    return qpow(x, exponent);
}

void NTT(int *a, int len, int type)
{
for(int i = 0; i < len; i++)
if(i < recover[i])swap(a[i], a[recover[i]]);
for(int k = 1; k < len; k <<= 1)
{
int x = qpow(type == 1 ? g : gi, (mod - 1) / (k << 1));
for(int i = 0; i < len; i += (k << 1))
{
int w = 1;
for(int j = 0; j < k; j++)
{
int y = a[i + j] % mod;
int z = w * a[i + j + k] % mod;
a[i + j] = (y + z) % mod;
a[i + j + k] = (y - z + mod) % mod;
w = w * x % mod;
}
}
}
if(type == -1)
{
int iv = inv(len);
for(int i = 0; i < len; i++)
a[i] = a[i] * iv % mod;
}
}

// Computes b = a^{-1} mod x^n by Newton iteration: the problem is solved
// modulo x^{ceil(n/2)} recursively, then one step b <- b * (2 - a * b) doubles
// the precision. a[0] must be invertible modulo `mod`. b is transformed over
// [0, len), so entries beyond the already computed prefix are expected to be
// zero (callers pass zero-filled arrays). Overwrites the global `recover`.
void inverse(int *a, int *b, int n)
{
    if(n == 1)
    {
        b[0] = inv(a[0]);
        return;
    }
    inverse(a, b, (n + 1) >> 1);

    // Transform length: smallest power of two strictly greater than 2n.
    int len = 1, cnt = 0;
    while(len <= (n << 1)) len <<= 1, cnt++;
    for(int i = 0; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    // Scratch copy of a mod x^n. It lives in static storage instead of on the
    // stack: an N-element automatic array in every recursive frame costs
    // megabytes per level and can overflow the stack. One shared buffer is
    // safe because it is only touched after the recursive call has returned.
    // The two loops below initialise all of [0, len), so no memset is needed.
    static int c[N];
    for(int i = 0; i < n; i++)
        c[i] = a[i];
    for(int i = n; i < len; i++)
        c[i] = 0;

    NTT(c, len, 1), NTT(b, len, 1);
    for(int i = 0; i < len; i++)
        b[i] = (2 - b[i] * c[i] % mod + mod) % mod * b[i] % mod;
    NTT(b, len, -1);

    // Truncate to x^n.
    for(int i = n; i < len; i++) b[i] = 0;
}

// Writes the product of polynomials a and b into c[0..len), where len is the
// smallest power of two strictly greater than n + m.
// Side effects: a and b are replaced by their forward transforms, and the
// global table `recover` is rebuilt for this len.
void mul(int *a, int *b, int *c, int n, int m)
{
    int size = 1, bits = 0;
    for(; size <= (n + m); size <<= 1) bits++;

    for(int idx = 0; idx < size; idx++)
        recover[idx] = (recover[idx >> 1] >> 1) | ((idx & 1) << (bits - 1));

    NTT(a, size, 1);
    NTT(b, size, 1);
    for(int idx = 0; idx < size; idx++)
        c[idx] = a[idx] * b[idx] % mod;
    NTT(c, size, -1);
}

// c[i] = (a[i] + k * b[i] + mod) % mod for every i in [0, max(n, m)]
// (inclusive upper bound). With k = -1 this subtracts b from a.
void add(int *a, int *b, int *c, int n, int m, int k)
{
    const int last = max(n, m);
    for(int idx = 0; idx <= last; idx++)
    {
        int scaled = k * b[idx];
        c[idx] = (a[idx] + scaled + mod) % mod;
    }
}

// Formal derivative: b[i - 1] = i * a[i] (mod `mod`) for 1 <= i < n, and the
// top slot b[n - 1] is cleared.
void diff(int *a, int *b, int n)
{
    int deg = 1;
    while(deg < n)
    {
        b[deg - 1] = deg * a[deg] % mod;
        deg++;
    }
    b[n - 1] = 0;
}

// Formal integral with zero constant term: b[i] = a[i - 1] / i (mod `mod`)
// for 1 <= i < n, and b[0] = 0.
void integ(int *a, int *b, int n)
{
    int deg = 1;
    while(deg < n)
    {
        int coeff = a[deg - 1];
        b[deg] = coeff * inv(deg) % mod;
        deg++;
    }
    b[0] = 0;
}

int F[N], G[N], H[N];

// Computes b = ln(a) mod x^n as the integral of a' * a^{-1}.
// Uses the global array H as scratch for the product.
void polyln(int *a, int *b, int n)
{
    // Heap-backed, zero-initialised scratch buffers. Two N-element automatic
    // arrays would put several megabytes on the stack for every call, and this
    // function is invoked from inside polyexp's recursion.
    vector<int> f(N), h(N);
    diff(a, f.data(), n);               // f = a'
    inverse(a, h.data(), n);            // h = a^{-1} mod x^n
    mul(f.data(), h.data(), H, n, n);   // H = a' * a^{-1}
    integ(H, b, n);                     // b = integral of H
}

// Computes b = exp(a) mod x^n by Newton iteration: after solving modulo
// x^{ceil(n/2)} recursively, one step b <- b * (1 - ln(b) + a) doubles the
// precision. The base case sets b[0] = 1, i.e. it assumes a[0] == 0.
void polyexp(int *a, int *b, int n)
{
    if(n == 1)
    {
        b[0] = 1;
        return;
    }
    polyexp(a, b, (n + 1) >> 1);

    // Length of the product computed by mul() below; needed to clear b's tail.
    // (No bit-reversal table is built here: polyln() and mul() each rebuild
    // `recover` themselves before any transform runs.)
    int len = 1;
    while(len <= (n << 1)) len <<= 1;

    // Heap-backed, zero-initialised scratch buffers, created after the
    // recursive call so only one level's buffers exist at a time. As automatic
    // arrays they would add megabytes of stack to every recursion frame.
    vector<int> c(N), f(N);
    c[0] = 1;
    polyln(b, f.data(), n);                        // f = ln(b) mod x^n
    add(c.data(), a, c.data(), n, n, 1);           // c = 1 + a
    add(c.data(), f.data(), c.data(), n, n, -1);   // c = 1 + a - ln(b)
    mul(c.data(), b, c.data(), n, n);              // c = b * (1 + a - ln(b))

    for(int i = 0; i < n; i++)
        b[i] = c[i];
    for(int i = n; i < len; i++)
        b[i] = 0;
}

int n;

// Reads n and the n coefficients of F, then prints the first n coefficients of
// exp(F) modulo x^n.
signed main()
{
    scanf("%lld", &n);

    int idx = 0;
    while(idx < n)
    {
        scanf("%lld", &F[idx]);
        idx++;
    }

    polyexp(F, G, n);

    for(idx = 0; idx < n; idx++)
        printf("%lld ", G[idx]);
    return 0;
}

多项式开根

多项式开根用来解决

$$G^2(x) \equiv F(x) \pmod {x^n}$$

假设我们已有 $G'^2(x) \equiv F(x) \pmod {x^{\frac{n}{2}}}$,并令 $H(G(x)) = G^2(x) - F(x)$,求 $G^2(x) \equiv F(x) \pmod {x^n}$。

$$G'^2(x) \equiv F(x) \pmod {x^{\frac{n}{2}}},\ G^2(x) \equiv F(x) \pmod {x^{\frac{n}{2}}} \\ G^2(x) - F(x) \equiv 0 \pmod {x^{\frac{n}{2}}} \\ H(G) \equiv 0 \pmod {x^{\frac{n}{2}}} \\ G \equiv G' - \frac{H(G')}{H'(G')} \pmod {x^n} \\ G \equiv \frac{G'^2 + F}{2G'} \pmod {x^n}$$

需要保证 $F_0 = 1$。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

#define int long long

const int N = 4e5 + 10;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

int recover[N];

// Returns a^b modulo the file-level constant `mod`, by binary exponentiation.
int qpow(int a, int b)
{
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

// Returns x^(mod - 2) modulo `mod`, i.e. the modular inverse of x by Fermat's
// little theorem.
int inv(int x)
{
    const int exponent = mod - 2;
    return qpow(x, exponent);
}

int inv2 = inv(2);

void NTT(int *a, int len, int type)
{
for(int i = 0; i < len; i++)
if(i < recover[i])swap(a[i], a[recover[i]]);
for(int k = 1; k < len; k <<= 1)
{
int x = qpow(type == 1 ? g : gi, (mod - 1) / (k << 1));
for(int i = 0; i < len; i += (k << 1))
{
int w = 1;
for(int j = 0; j < k; j++)
{
int y = a[i + j] % mod;
int z = w * a[i + j + k] % mod;
a[i + j] = (y + z) % mod;
a[i + j + k] = (y - z + mod) % mod;
w = w * x % mod;
}
}
}
if(type == -1)
{
int iv = inv(len);
for(int i = 0; i < len; i++)
a[i] = a[i] * iv % mod;
}
}

// Computes b = a^{-1} mod x^n by Newton iteration: the problem is solved
// modulo x^{ceil(n/2)} recursively, then one step b <- b * (2 - a * b) doubles
// the precision. a[0] must be invertible modulo `mod`. b is transformed over
// [0, len), so entries beyond the already computed prefix are expected to be
// zero (callers pass zero-filled arrays). Overwrites the global `recover`.
void inverse(int *a, int *b, int n)
{
    if(n == 1)
    {
        b[0] = inv(a[0]);
        return;
    }
    inverse(a, b, (n + 1) >> 1);

    // Transform length: smallest power of two strictly greater than 2n.
    int len = 1, cnt = 0;
    while(len <= (n << 1)) len <<= 1, cnt++;
    for(int i = 0; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    // Scratch copy of a mod x^n. It lives in static storage instead of on the
    // stack: an N-element automatic array in every recursive frame costs
    // megabytes per level and can overflow the stack. One shared buffer is
    // safe because it is only touched after the recursive call has returned.
    // The two loops below initialise all of [0, len), so no memset is needed.
    static int c[N];
    for(int i = 0; i < n; i++)
        c[i] = a[i];
    for(int i = n; i < len; i++)
        c[i] = 0;

    NTT(c, len, 1), NTT(b, len, 1);
    for(int i = 0; i < len; i++)
        b[i] = (2 - b[i] * c[i] % mod + mod) % mod * b[i] % mod;
    NTT(b, len, -1);

    // Truncate to x^n.
    for(int i = n; i < len; i++) b[i] = 0;
}

// Writes the product of polynomials a and b into c[0..len), where len is the
// smallest power of two strictly greater than n + m.
// Side effects: a and b are replaced by their forward transforms, and the
// global table `recover` is rebuilt for this len.
void mul(int *a, int *b, int *c, int n, int m)
{
    int size = 1, bits = 0;
    for(; size <= (n + m); size <<= 1) bits++;

    for(int idx = 0; idx < size; idx++)
        recover[idx] = (recover[idx >> 1] >> 1) | ((idx & 1) << (bits - 1));

    NTT(a, size, 1);
    NTT(b, size, 1);
    for(int idx = 0; idx < size; idx++)
        c[idx] = a[idx] * b[idx] % mod;
    NTT(c, size, -1);
}

// c[i] = (a[i] + k * b[i] + mod) % mod for every i in [0, max(n, m)]
// (inclusive upper bound). With k = -1 this subtracts b from a.
void add(int *a, int *b, int *c, int n, int m, int k)
{
    const int last = max(n, m);
    for(int idx = 0; idx <= last; idx++)
    {
        int scaled = k * b[idx];
        c[idx] = (a[idx] + scaled + mod) % mod;
    }
}
int F[N], G[N], H[N];

// Computes b with b^2 = a mod x^n by Newton iteration: after solving modulo
// x^{ceil(n/2)} recursively, one step b <- (b + a / b) / 2 doubles the
// precision. The base case sets b[0] = 1, i.e. it assumes a[0] == 1.
void polysqrt(int *a, int *b, int n)
{
    if(n == 1)
    {
        b[0] = 1;
        return;
    }
    polysqrt(a, b, (n + 1) >> 1);

    // Heap-backed, zero-initialised scratch buffers, created after the
    // recursive call so only one level's buffers exist at a time. As automatic
    // arrays they would add megabytes of stack to every recursion frame.
    vector<int> c(N), f(N);
    inverse(b, f.data(), n);            // f = b^{-1} mod x^n

    // Build the bit-reversal table only now: inverse() overwrites `recover`,
    // so filling it afterwards guarantees it matches the len used below.
    int len = 1, cnt = 0;
    while(len <= (n << 1)) len <<= 1, cnt++;
    for(int i = 0; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    for(int i = 0; i < n; i++)
        c[i] = a[i];                    // c = a mod x^n

    NTT(f.data(), len, 1), NTT(b, len, 1), NTT(c.data(), len, 1);
    for(int i = 0; i < len; i++)
        b[i] = (b[i] + c[i] * f[i] % mod) % mod * inv2 % mod;
    NTT(b, len, -1);

    // Truncate to x^n.
    for(int i = n; i < len; i++)
        b[i] = 0;
}

int n;

// Reads n and the n coefficients of F, then prints the first n coefficients of
// the square root of F modulo x^n.
signed main()
{
    scanf("%lld", &n);

    int idx = 0;
    while(idx < n)
    {
        scanf("%lld", &F[idx]);
        idx++;
    }

    polysqrt(F, G, n);

    for(idx = 0; idx < n; idx++)
        printf("%lld ", G[idx]);
    return 0;
}

多项式幂函数

多项式幂函数是用来解决

$$G(x) \equiv (F(x))^k \pmod {x^n}$$

先求一遍 $\ln$,然后乘以 $k$,再使用 $\exp$,就好啦(即 $F^k = e^{k \ln F}$)。

需保证 $F_0 = 1$。

多项式的一些普通情况

多项式求ln

不保证 $F_0 = 1$。此时不存在,有定理:

在模意义下,当且仅当 $F_0 = 1$ 时,$F(x)$ 有对数多项式。

多项式求exp

不保证 $F_0 = 0$。同多项式求 $\ln$。

多项式开根

不保证 $F_0 = 1$,但保证 $F_0$ 是 $\bmod\ 998244353$ 下的二次剩余。

边界求一遍二次剩余即可。

多项式幂函数

不保证 $F_0 = 1$。可以先找到系数不为 $0$ 的最低次项(设其次数为 $t$),然后让式子除以这一项,最后再乘回来就好了:

$$F(x)^k = \bigg( \frac{F(x)}{x^t} \bigg)^k x^{tk}$$

分治FFT/NTT

给定序列 $g$,序列 $f$ 满足

$$f_i = \sum_{j = 1}^{i} f_{i - j}\, g_j$$

求 $f$。这里给出一个多项式求逆的方法(找时间再补分治 FFT / NTT)。

记 $F(x) = \sum_{i = 0}^{\infty} f_i x^i$,$G(x) = \sum_{i = 0}^{\infty} g_i x^i$,且 $g_0 = 0$。

所以有(最后一步取 $f_0 = 1$)

$$F(x) G(x) = \sum_{i = 0}^{\infty} x^i \sum_{j + k = i} f_j g_k = F(x) - f_0 x^0 \\ F(x) G(x) \equiv F(x) - f_0 \pmod {x^n} \\ F(x) \equiv (1 - G(x))^{-1} \pmod {x^n}$$

下降幂多项式乘法

假设我们已知 $n$ 次多项式 $f(x)$ 在 $[0, n]$ 的点值,求它的下降幂表示。

设 $f(x) = \sum_{i = 0}^{n} b_i x^{\underline{i}} = \sum_{i = 0}^{n} b_i \frac{x!}{(x - i)!}$,则有

$$\frac{f(x)}{x!} = \sum_{i = 0}^{n} b_i \frac{1}{(x - i)!} = b \ast e^x$$

即点值的指数生成函数等于下降幂系数的生成函数卷上 $e^x$。做乘法时,先卷上 $e^x$ 转化为点值,点值对应相乘,最后卷上 $e^{-x}$ 还原即可。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

#define int long long

const int mod = 998244353;
const int g = 3;
const int gi = 332748118;
const int N = 8e5 + 10;

int recover[N];

int n, m;

int A[N], B[N], F[N], G[N], H[N];

// Returns a^b modulo the file-level constant `mod`, by binary exponentiation.
int qpow(int a, int b)
{
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

// Returns x^(mod - 2) modulo `mod`, i.e. the modular inverse of x by Fermat's
// little theorem.
int inv(int x)
{
    const int exponent = mod - 2;
    return qpow(x, exponent);
}

void NTT(int *a, int len, int type)
{
for(int i = 0; i < len; i++)
if(i < recover[i])swap(a[i], a[recover[i]]);
for(int k = 1; k < len; k <<= 1)
{
int x = qpow(type == 1 ? g : gi, (mod - 1) / (k << 1));
for(int i = 0; i < len; i += (k << 1))
{
int w = 1;
for(int j = 0; j < k; j++)
{
int y = a[i + j] % mod;
int z = w * a[i + j + k] % mod;
a[i + j] = (y + z) % mod;
a[i + j + k] = ((y - z) % mod + mod) % mod;
w = w * x % mod;
}
}
}
if(type == -1)
{
int iv = inv(len);
for(int i = 0; i < len; i++)
a[i] = a[i] * iv % mod;
}
}

int fac[N], ifac[N];

// Multiplies two polynomials given in falling-factorial basis.
// Steps: convolve each input with A (coefficients 1/i!, i.e. e^x), multiply
// the results index by index together with i!, then convolve with B
// (coefficients (-1)^i / i!, i.e. e^{-x}) and print the first n + m + 1
// coefficients.
signed main()
{
    scanf("%lld%lld", &n, &m);
    for(int i = 0; i <= n; i++)
        scanf("%lld", &F[i]);
    for(int i = 0; i <= m; i++)
        scanf("%lld", &G[i]);

    int len = 1, cnt = 0, Len = max(n, m) << 1;
    while(len <= (Len << 1)) len <<= 1, cnt++;
    // recover[0] is already 0 (zero-initialised global). Starting at 1 means
    // the body only runs when len >= 2, so the shift amount cnt - 1 is never
    // negative, even when Len == 0.
    for(int i = 1; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));

    // Factorials, then inverse factorials from a single modular inverse:
    // 1/(i-1)! = (1/i!) * i. This replaces one exponentiation per index.
    fac[0] = 1;
    for(int i = 1; i <= Len; i++)
        fac[i] = fac[i - 1] * i % mod;
    ifac[Len] = inv(fac[Len]);
    for(int i = Len; i >= 1; i--)
        ifac[i - 1] = ifac[i] * i % mod;

    for(int i = 0; i <= Len; i++)
    {
        A[i] = ifac[i];                                      // e^x
        B[i] = (i & 1) ? (mod - ifac[i]) % mod : ifac[i];    // e^{-x}
    }

    NTT(A, len, 1); NTT(B, len, 1);
    NTT(F, len, 1); NTT(G, len, 1);
    for(int i = 0; i < len; i++)
    {
        F[i] = A[i] * F[i] % mod;
        G[i] = A[i] * G[i] % mod;
    }
    NTT(F, len, -1); NTT(G, len, -1);

    // Index-wise product, scaled by i!; entries above Len stay zero.
    for(int i = 0; i <= Len; i++)
        H[i] = F[i] % mod * G[i] % mod * fac[i] % mod;

    NTT(H, len, 1);
    for(int i = 0; i < len; i++)
        H[i] = H[i] * B[i] % mod;
    NTT(H, len, -1);

    for(int i = 0; i <= n + m; i++)
        printf("%lld ", H[i]);
    return 0;
}

多项式除法

给定一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x)$,求多项式 $A(x), B(x)$,满足:

  • $A(x)$ 次数为 $n - m$,$B(x)$ 次数小于 $m$
  • $F(x) = A(x) \ast G(x) + B(x)$

所有运算在模 $998244353$ 下进行。

定义一种让多项式反转的操作为 $A'(x) = x^n A(\frac{1}{x})$(其中 $n$ 取该多项式的次数),然后化简式子:

$$F(x) = A(x) \ast G(x) + B(x) \\ x^n F(\tfrac{1}{x}) = x^{n - m} A(\tfrac{1}{x}) \ast x^m G(\tfrac{1}{x}) + x^{n - m + 1} \ast x^{m - 1} B(\tfrac{1}{x}) \\ F'(x) = A'(x) \ast G'(x) + x^{n - m + 1} \ast B'(x) \\ F'(x) \equiv A'(x) \ast G'(x) \pmod {x^{n - m + 1}} \\ A'(x) \equiv F'(x) \ast G'^{-1}(x) \pmod {x^{n - m + 1}}$$

先进行多项式求逆求出 $A(x)$,然后再由 $B(x) = F(x) - A(x) \ast G(x)$ 推出 $B(x)$。

多项式多点求值

咕咕咕

多项式复合函数

给定 $F(x), G(x)$,求

$$H(x) \equiv F(G(x)) \pmod {x^{n + 1}}$$

即:

$$H(x) \equiv \sum_{i = 0}^{n} [x^i] F(x) \times G(x)^i \pmod {x^{n + 1}}$$

系数对 $998244353$ 取模。

设 $m = \sqrt n$(向上取整),则

$$\sum_{i = 0}^{n} [x^i] F(x)\, G(x)^i = \sum_{i = 0}^{m - 1} \sum_{j = 0}^{m - 1} [x^{im + j}] F(x)\, G(x)^{im + j} = \sum_{i = 0}^{m - 1} G(x)^{im} \sum_{j = 0}^{m - 1} [x^{im + j}] F(x)\, G(x)^j$$

然后预处理 $G(x)^{im}$ 和 $G(x)^j$,其它的直接暴力计算,时间复杂度 $O(n^2 + n \sqrt n \log n)$。

参考代码如下:
#include <bits/stdc++.h>

using namespace std;

namespace Poly// 使用NTT实现
{
#define int long long
#define vec vector <int>
const int mod = 998244353; // 模数
const int g = 3; // 原根
const int gi = 332748118; // 逆元
const int N = 8e4 + 10; // size

int save[3][32];
int recover[N];

// Fast exponentiation: returns a^b modulo `mod`.
int qpow(int a, int b)
{
    int result = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

// Modular inverse of x computed as x^(mod - 2) (Fermat's little theorem).
int inv(int x)
{
    const int exponent = mod - 2;
    return qpow(x, exponent);
}

void prework()
{
for(int i = 1, k = 1; i <= 20; i++, k <<= 1)
{
save[0][i] = qpow(g, (mod - 1) / (k << 1));
save[1][i] = qpow(gi, (mod - 1) / (k << 1));
save[2][i] = inv(k);
}
}

// Prepares a transform for operands whose lengths are n and m.
//   len (out): the smallest power of two strictly greater than n + m.
// Also rebuilds `recover` so that recover[i] is i with its low log2(len) bits
// reversed; NTT() uses it for the initial permutation.
void init(int n, int m, int &len)
{
    len = 1; int cnt = 0;
    while(len <= (n + m)) len <<= 1, cnt++;
    // Index 0 always maps to itself. Writing it outside the loop keeps the
    // shift count (cnt - 1) non-negative: the loop body below only runs when
    // len >= 2, i.e. cnt >= 1. For n + m <= 0 the count would otherwise be -1,
    // which is undefined behaviour.
    recover[0] = 0;
    for(int i = 1; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
}

// In-place number-theoretic transform of the first `len` entries of `a`.
//   len  : transform length; `recover` must have been built for this length.
//   type : 1 uses the forward roots save[0][*]; any other value uses the
//          inverse roots save[1][*]. When type == -1 the result is also
//          multiplied by save[2][level], the inverse of len.
void NTT(vec &a, int len, int type)
{
    // Bit-reversal permutation; each pair is swapped once.
    for(int pos = 0; pos < len; pos++)
    {
        const int rev = recover[pos];
        if(pos < rev) swap(a[pos], a[rev]);
    }

    const int table = (type == 1) ? 0 : 1;
    int level = 1;
    for(int half = 1; half < len; half <<= 1)
    {
        const int step = save[table][level];
        const int width = half << 1;
        for(int base = 0; base < len; base += width)
        {
            int w = 1;
            for(int j = 0; j < half; j++, w = w * step % mod)
            {
                int &lo = a[base + j];
                int &hi = a[base + j + half];
                const int even = lo % mod;
                const int odd = w * hi % mod;
                lo = (even + odd) % mod;
                hi = ((even - odd) % mod + mod) % mod;
            }
        }
        level++;
    }

    if(type == -1)
    {
        // After the loop, level == log2(len) + 1, so this is the inverse of len.
        const int scale = save[2][level];
        for(int pos = 0; pos < len; pos++)
            a[pos] = a[pos] * scale % mod;
    }
}

struct poly
{
vector <int> v; int len;
poly(){v.resize(N); len = 0;}
void clear(int n){v.clear(); v.resize(N); len = n;}
void length(int n){len = n;}
void memset0(int l, int r){for(int i = l; i < r; i++)v[i] = 0;}
void print(int n){for(int i = 0; i < n; i++)printf("%lld ", v[i]); printf("\n");}

friend poly operator + (poly A, poly B)
{
A.length(max(A.len, B.len));
for(int i = 0; i <= A.len; i++)
A.v[i] = (A.v[i] + B.v[i]) % mod;
return A;
}

friend poly operator - (poly A, poly B)
{
A.length(max(A.len, B.len));
for(int i = 0; i <= A.len; i++)
A.v[i] = ((A.v[i] - B.v[i]) % mod + mod) % mod;
return A;
}

friend poly operator * (poly A, poly B)
{
int len; init(A.len, B.len, len);
NTT(A.v, len, 1), NTT(B.v, len, 1);
for(int i = 0; i < len; i++)
A.v[i] = (A.v[i] * B.v[i]) % mod;
NTT(A.v, len, -1); A.len += B.len;
return A;
}
};

// Scratch buffer for inverse(): A truncated to n coefficients, in NTT form.
vec tmp;

// Computes B = A^{-1} mod x^n by Newton iteration, B <- B * (2 - A * B),
// raising the precision from ceil(n / 2) to n on each level of the recursion.
//   A : input polynomial; only A.v[0..n) is read. A.v[0] must be invertible.
//   B : output. The transform reads B.v[0..len), so B must be zero beyond the
//       prefix computed so far (e.g. freshly cleared by the caller).
//   n : number of coefficients wanted; n <= 0 is a no-op.
// Side effects: overwrites `tmp` and `recover`.
void inverse(poly &A, poly &B, int n)
{
    // Nothing to compute modulo x^0. Without this guard the halving below
    // never reaches the n == 1 base case and the recursion does not terminate.
    if(n <= 0) return;
    if(n == 1){B.v[0] = inv(A.v[0]); return;}
    inverse(A, B, (n + 1) >> 1);
    int len; init(n, n, len);
    tmp.clear(); tmp.resize(len);
    for(int i = 0; i < n; i++) tmp[i] = A.v[i];
    NTT(tmp, len, 1), NTT(B.v, len, 1);
    for(int i = 0; i < len; i++)
        B.v[i] = (2 - B.v[i] * tmp[i] % mod + mod) % mod * B.v[i] % mod;
    NTT(B.v, len, -1);
    // Drop everything at or above x^n so B is a clean residue again.
    for(int i = n; i < len; i++)B.v[i] = 0;
}

// Formal derivative of the first n coefficients of A, written into B:
// B[i] = (i + 1) * A[i + 1] for i + 1 < n. B[n - 1] is then zeroed and n is
// recorded as B's length.
void diff(poly &A, poly &B, int n)
{
    for(int i = 0; i + 1 < n; i++)
    {
        const int degree = i + 1;
        B.v[i] = degree * A.v[degree] % mod;
    }
    B.v[n - 1] = 0;
    B.length(n);
}

// Formal antiderivative with zero constant term, written into B:
// B[d] = A[d - 1] / d for 1 <= d < n, B[0] = 0. n is recorded as B's length.
// Each division is a separate modular inverse computed with inv().
void integ(poly &A, poly &B, int n)
{
    int degree = 1;
    while(degree < n)
    {
        const int coef = A.v[degree - 1];
        B.v[degree] = coef * inv(degree) % mod;
        degree++;
    }
    B.v[0] = 0;
    B.length(n);
}

// Shared scratch polynomials: Ln() uses E and F, Exp() uses C and D,
// Sqrt() uses G and H, Pow() uses I.
poly C, D, E, F, G, H, I;

// B = ln(A) mod x^n, obtained as the antiderivative of A' * A^{-1}.
// Overwrites the scratch polynomials E and F.
void Ln(poly &A, poly &B, int n)
{
    E.clear(n);
    F.clear(n);

    diff(A, E, n);       // E = A'
    inverse(A, F, n);    // F = A^{-1} mod x^n

    E = E * F;           // E = A' / A
    integ(E, B, n);      // reads only E.v[0..n-1)
    B.length(n);
}

// B = exp(A) mod x^n by Newton iteration, B <- B * (1 + A - ln B), with the
// precision raised from ceil(n / 2) to n on each level of the recursion.
// The base case stores 1 as the constant term, which is exp(A) mod x only
// when A.v[0] == 0. Overwrites the scratch polynomials C and D (and, through
// Ln(), E and F).
void Exp(poly &A, poly &B, int n)
{
    if(n == 1)
    {
        B.v[0] = 1;
        return;
    }
    Exp(A, B, (n + 1) >> 1);

    int len;
    init(n, n, len);

    C.clear(n);
    D.clear(n);
    C.v[0] = 1;
    Ln(B, D, n);
    C = B * (C + A - D);

    // Keep the low n coefficients of the product and zero the rest of the
    // [0, len) window.
    int idx = 0;
    for(; idx < n; idx++) B.v[idx] = C.v[idx];
    for(; idx < len; idx++) B.v[idx] = 0;
}

// Inverse of 2 modulo `mod`, used by the Newton step in Sqrt().
const int inv2 = inv(2);

// B = sqrt(A) mod x^n by Newton iteration, B <- (B + A / B) / 2, with the
// precision raised from ceil(n / 2) to n on each level of the recursion.
// The base case stores 1 as the constant term, i.e. it takes the root of a
// series whose constant term is 1. Overwrites the scratch polynomials G and H.
void Sqrt(poly &A, poly &B, int n)
{
    if(n == 1)
    {
        B.v[0] = 1;
        return;
    }
    Sqrt(A, B, (n + 1) >> 1);

    int len;
    init(n, n, len);

    G.clear(n);
    H.clear(n);
    inverse(B, H, n);                       // H = B^{-1} mod x^n
    for(int i = 0; i < n; i++) G.v[i] = A.v[i];

    NTT(H.v, len, 1);
    NTT(B.v, len, 1);
    NTT(G.v, len, 1);
    for(int i = 0; i < len; i++)
    {
        const int quotient = G.v[i] * H.v[i] % mod;
        B.v[i] = (B.v[i] + quotient) % mod * inv2 % mod;
    }
    NTT(B.v, len, -1);

    for(int i = n; i < len; i++) B.v[i] = 0;
}

// B = A^k mod x^n computed as exp(k * ln A). Overwrites the scratch
// polynomial I (and, through Ln() and Exp(), C, D, E and F).
void Pow(poly &A, poly &B, int n, int k)
{
    I.clear(n);
    Ln(A, I, n);
    for(int i = 0; i < n; i++)
        I.v[i] = I.v[i] * k % mod;
    Exp(I, B, n);
}

#undef int
}

using namespace Poly;

int n, m;

// Gpow[j] = G^j mod x^n (baby steps), Gm[i] = G^(i * l) mod x^n (giant steps).
poly Gpow[200], Gm[200];

// Reads n, m, then the n + 1 coefficients of F and the m + 1 coefficients of G,
// and prints the first n + 1 coefficients of F(G(x)) modulo 998244353.
// The composition is evaluated in blocks of l = floor(sqrt(n + 1)) + 1 terms:
//   sum_i G^(i*l) * ( sum_j F[i*l + j] * G^j ).
int main()
{
    prework();
    cin >> n >> m;
    n = n + 1; m = m + 1; // from here on n and m are coefficient counts
    poly F, G, H;
    F.clear(n); G.clear(m);
    for(int i = 0; i < n; i++)
        cin >> F.v[i];
    for(int i = 0; i < m; i++)
        cin >> G.v[i];
    // Only G mod x^n influences the answer. Truncate it up front: the powers
    // below are cleaned with memset0(n, 2n), which is enough only if no factor
    // has coefficients at index n or above. Otherwise residue survives beyond
    // 2n and products can exceed the transform length and wrap around into the
    // low coefficients.
    if(m > n)
    {
        G.memset0(n, m);
        m = n;
        G.length(m);
    }
    int l = sqrt(n) + 1;
    Gpow[0].v[0] = 1; Gm[0].v[0] = 1;
    Gpow[0].length(n); Gm[0].length(n);
    Gpow[1] = G; Gpow[1].length(n);
    for(int i = 2; i <= l; i++)
    {
        Gpow[i] = Gpow[i - 1] * G;
        Gpow[i].length(n);
        Gpow[i].memset0(n, n << 1); // truncate to x^n
    }
    Gm[1] = Gpow[l]; Gm[1].length(n);
    for(int i = 2; i <= l; i++)
    {
        Gm[i] = Gm[i - 1] * Gpow[l];
        Gm[i].length(n);
        Gm[i].memset0(n, n << 1); // truncate to x^n
    }
    poly A; A.clear(n); H.clear(n);
    for(int i = 0; i < l; i++)
    {
        // A = sum_j F[i*l + j] * G^j, accumulated coefficient by coefficient.
        // F.v is zero beyond its n coefficients, so out-of-range terms vanish.
        A.memset0(0, n << 1); A.length(n);
        for(int j = 0; j < l; j++)
        {
            for(int k = 0; k < n; k++)
                A.v[k] = (A.v[k] + F.v[i * l + j] * Gpow[j].v[k]) % mod;
        }
        A = A * Gm[i]; A.length(n); H = H + A;
    }
    for(int i = 0; i < n; i++)
        cout << H.v[i] << " ";
    return 0;
}

多项式复合逆

给定 $F(x)$,求

$$G(F(x)) \equiv x \pmod{x^n}$$

对 $998244353$ 取模。

有拉格朗日反演公式:

$$[x^n]G(x) = \frac{1}{n}[x^{n - 1}]\left(\frac{x}{F(x)}\right)^n$$

这个是求单项系数的, 需要求的是所有系数。

利用解决多项式复合函数的方法,令 $m = \lfloor \sqrt n \rfloor + 1$(与代码中的 `sqrt(n) + 1` 一致,次数不小于 $n$ 的项在模 $x^n$ 意义下舍去),则

$$
\begin{aligned}
G(x) &= \sum_{i = 1}^{n} \bigg( \frac{1}{i}[x^{i - 1}] \Big(\frac{x}{F(x)}\Big)^{i} \bigg) x^{i} \\
&= \sum_{i = 0}^{m - 1} \sum_{j = 1}^{m} \bigg( \frac{1}{im + j}[x^{im + j - 1}] \Big(\frac{x}{F(x)}\Big)^{im + j} \bigg) x^{im + j} \\
&= \sum_{i = 0}^{m - 1} \sum_{j = 1}^{m} \bigg( \frac{1}{im + j}[x^{im + j - 1}] \Big(\frac{x}{F(x)}\Big)^{im} \Big(\frac{x}{F(x)}\Big)^{j} \bigg) x^{im + j}
\end{aligned}
$$

然后就是和多项式复合函数一样的处理方法:预处理 $\big(\frac{x}{F(x)}\big)^{j}$ 和 $\big(\frac{x}{F(x)}\big)^{im}$,每个系数再暴力计算,时间复杂度 $O(n^2 + n \sqrt n \log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <bits/stdc++.h>

using namespace std;

// From here on `int` expands to `long long` for the rest of the file, so the
// product of two residues below `mod` is formed in long long arithmetic
// (this is also why the entry point is declared `signed main`).
#define int long long
#define vec vector <int>

const int N = 8e4 + 10; // capacity of every coefficient buffer and of `recover`
const int mod = 998244353; // modulus of all arithmetic
const int g = 3; // base of the forward roots of unity
const int gi = 332748118; // modular inverse of g (3 * 332748118 = mod + 1)

struct poly
{
vector <int> v;
poly(){v.resize(N);}
};

// Lookup tables filled by prework(): save[0][i] and save[1][i] hold g and gi
// raised to (mod - 1) / 2^i, save[2][i] holds the inverse of 2^(i - 1).
int save[3][32];
// Bit-reversal permutation for the current transform length; rebuilt by
// inverse() and main() before they call NTT().
int recover[N];

// Binary exponentiation: returns a^b reduced modulo `mod` (1 when b == 0).
int qpow(int a, int b)
{
    int acc = 1;
    for(; b != 0; b >>= 1)
    {
        if(b & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}

// Modular inverse of x computed as x^(mod - 2) (Fermat's little theorem).
int inv(int x)
{
    const int exponent = mod - 2;
    return qpow(x, exponent);
}

void prework()
{
for(int i = 1, k = 1; i <= 20; i++, k <<= 1)
{
save[0][i] = qpow(g, (mod - 1) / (k << 1));
save[1][i] = qpow(gi, (mod - 1) / (k << 1));
save[2][i] = inv(k);
}
}

// In-place number-theoretic transform of the first `len` entries of `a`.
//   len  : transform length; `recover` must have been built for this length.
//   type : 1 uses the forward roots save[0][*]; any other value uses the
//          inverse roots save[1][*]. When type == -1 the result is also
//          multiplied by save[2][level], the inverse of len.
void NTT(vec &a, int len, int type)
{
    // Bit-reversal permutation; each pair is swapped once.
    for(int pos = 0; pos < len; pos++)
    {
        const int rev = recover[pos];
        if(pos < rev) swap(a[pos], a[rev]);
    }

    const int table = (type == 1) ? 0 : 1;
    int level = 1;
    for(int half = 1; half < len; half <<= 1)
    {
        const int step = save[table][level];
        const int width = half << 1;
        for(int base = 0; base < len; base += width)
        {
            int w = 1;
            for(int j = 0; j < half; j++, w = w * step % mod)
            {
                int &lo = a[base + j];
                int &hi = a[base + j + half];
                const int even = lo % mod;
                const int odd = w * hi % mod;
                lo = (even + odd) % mod;
                hi = ((even - odd) % mod + mod) % mod;
            }
        }
        level++;
    }

    if(type == -1)
    {
        // After the loop, level == log2(len) + 1, so this is the inverse of len.
        const int scale = save[2][level];
        for(int pos = 0; pos < len; pos++)
            a[pos] = a[pos] * scale % mod;
    }
}

// Scratch buffer for inverse(): A truncated to n coefficients, in NTT form.
vec tmp;

// Computes B = A^{-1} mod x^n by Newton iteration, B <- B * (2 - A * B),
// raising the precision from ceil(n / 2) to n on each level of the recursion.
//   A : input polynomial; only A.v[0..n) is read. A.v[0] must be invertible.
//   B : output. The transform reads B.v[0..len), so B must be zero beyond the
//       prefix computed so far (a freshly constructed poly satisfies this).
//   n : number of coefficients wanted; n <= 0 is a no-op.
// Side effects: overwrites `tmp` and `recover`.
void inverse(poly &A, poly &B, int n)
{
    // Nothing to compute modulo x^0. Without this guard the halving below
    // never reaches the n == 1 base case and the recursion does not terminate.
    if(n <= 0) return;
    if(n == 1){B.v[0] = inv(A.v[0]);return;}
    inverse(A, B, (n + 1) >> 1);
    // Transform length: smallest power of two strictly greater than 2n.
    int len = 1, cnt = 0;
    while(len <= (n << 1))len <<= 1, cnt++;
    for(int i = 0; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
    tmp.clear(); tmp.resize(len);
    for(int i = 0; i < n; i++) tmp[i] = A.v[i];
    NTT(tmp, len, 1), NTT(B.v, len, 1);
    for(int i = 0; i < len; i++)
        B.v[i] = (2 - B.v[i] * tmp[i] % mod + mod) % mod * B.v[i] % mod;
    NTT(B.v, len, -1);
    // Drop everything at or above x^n so B is a clean residue again.
    for(int i = n; i < len; i++)B.v[i] = 0;
}

int n;

// With P = x / F(x): Finv[j] = P^j mod x^n (baby steps) and
// Fm[i] = P^(i * m) mod x^n (giant steps).
poly Finv[200], Fm[200];

// Reads n and the n coefficients of F, then prints n coefficients of the
// compositional inverse G (G(F(x)) = x) modulo 998244353, using Lagrange
// inversion: [x^k] G = (1 / k) * [x^(k - 1)] (x / F)^k.
signed main()
{
    poly F; prework();
    cin >> n;
    for(int i = 0; i < n; i++)
        cin >> F.v[i];
    // Shift down by one: F now holds F(x) / x, and n counts its coefficients.
    for(int i = 0; i < n; i++)
        F.v[i] = F.v[i + 1];
    n = n - 1;
    if(n <= 0)
    {
        // Only the constant term was requested, and it is always 0. Returning
        // here also avoids calling inverse() with a non-positive length.
        for(int i = 0; i <= n; i++)
            cout << 0 << " ";
        return 0;
    }
    int m = sqrt(n) + 1;
    poly A; inverse(F, A, n); // A = x / F mod x^n
    Finv[0].v[0] = Fm[0].v[0] = 1;
    // Transform length for products of two length-n polynomials.
    int len = 1, cnt = 0;
    while(len <= (n << 1))len <<= 1, cnt++;
    for(int i = 0; i < len; i++)
        recover[i] = (recover[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
    Finv[1] = A; NTT(A.v, len, 1); // A stays in NTT form as the fixed factor
    for(int i = 2; i <= m; i++)
    {
        NTT(Finv[i - 1].v, len, 1);
        for(int j = 0; j < len; j++)
            Finv[i].v[j] = Finv[i - 1].v[j] * A.v[j] % mod;
        NTT(Finv[i].v, len, -1);
        NTT(Finv[i - 1].v, len, -1); // restore the previous power to coefficient form
        for(int j = n; j < (n << 1); j++)
            Finv[i].v[j] = 0; // truncate to x^n
    }
    A = Finv[m]; NTT(A.v, len, 1); // fixed factor for the giant steps
    Fm[1] = Finv[m];
    for(int i = 2; i <= m; i++)
    {
        NTT(Fm[i - 1].v, len, 1);
        for(int j = 0; j < len; j++)
            Fm[i].v[j] = Fm[i - 1].v[j] * A.v[j] % mod;
        NTT(Fm[i].v, len, -1);
        NTT(Fm[i - 1].v, len, -1); // restore the previous power to coefficient form
        for(int j = n; j < (n << 1); j++)
            Fm[i].v[j] = 0; // truncate to x^n
    }
    poly G; bool res = false;
    for(int i = 0; i <= m; i++)
    {
        for(int j = 1; j <= m; j++)
        {
            if(i * m + j - 1 > n)
            {
                res = true;
                break;
            }
            // Single coefficient [x^(i*m + j - 1)] of Finv[j] * Fm[i],
            // computed directly as a convolution sum.
            int sum = 0;
            for(int k = 0; k <= i * m + j - 1; k++)
                sum = (sum + Finv[j].v[k] * Fm[i].v[i * m + j - 1 - k] % mod) % mod;
            G.v[i * m + j] = sum * inv(i * m + j) % mod;
        }
        if(res)break;
    }
    for(int i = 0; i <= n; i++)
        cout << G.v[i] << " ";
    return 0;
}
作者

Jekyll_Y

发布于

2022-09-26

更新于

2023-03-02

许可协议

评论