拉格朗日插值法

$\ \ \ \ \ \ \ \,$关于拉格朗日插值法的复习笔记:

$\ \ \ \ \ \ \,$拉格朗日插值法,是解决多项式点值表达式函数值的问题的算法,具体而言,问题如下:

$\ \ \ \ \ \ \,$已知在二维平面上,一多项式$f$的函数图像经过 $n+1$ 个点 $(x_i,y_i)$ ,既我们知道 $f(x_i)=y_i$,求$f(k)$的值。

初步分析

$\ \ \ \ \ \ \,$显然,经过 $n+1$ 个点的多项式函数一定是$n$次函数,那么也很显然的,我们可以列出一个 $n$ 元方程组来解决这个问题,形如:

$\begin{cases}y_1=a_1\cdot x_1^n+a_2\cdot x_1^{n-1}+a_3\cdot x_1^{n-2}+…+a_{n+1}\\y_2=a_1\cdot x_2^n+a_2\cdot x_2^{n-1}+a_3\cdot x_2^{n-2}+…+a_{n+1}\\\ \ \ \ \ \ \ \ \ \ \ \ \ …\\y_n=a_1\cdot x_n^n+a_2\cdot x_n^{n-1}+a_3\cdot x_n^{n-2}+…+a_{n+1}\end{cases}$

$\ \ \ \ \ \ \,$直观的,我们可以高斯消元来做,复杂度$O(n^3)$,复杂度多半已上天,那么我们如何快速处理呢?

插值法

$\ \ \ \ \ \ \,$我们现在知道的是 $f(x_i)=y_i$,那么我们想怎么快速表达出这个多项式,我们设定一种函数 $S_i$:

$S_i(x)=[x=x_i]$

$\ \ \ \ \ \ \,$它的意义是只有当 $x$ 为 $x_i$ 函数值才为 $1$,其他为 $0$,那么显然有:

$f=\sum_{i=1}^n y_i\cdot S_i$

$\ \ \ \ \ \ \,$很显然的,$S_i$ 是个 $n$ 次多项式,而我们的答案也就是:

$f(k)=\sum_{i=1}^n y_i\cdot S_i(k)$

$\ \ \ \ \ \ \,$遗憾地告诉你,光靠这个办法是不能还原$f$的函数图像的,只能得到近似图像,所以插值法是有一定误差的。

$\ \ \ \ \ \ \,$现在我们简化了问题,如何求$S_i$ 是个 $n$ 次多项式。

$\ \ \ \ \ \ \,$其实满足$S_i(x)$的函数取值合法挺简单的,主要问题是如何满足他是 $n$ 次多项式的事实,那么我们就先把它化作下面的形态:

$S_i(k)=\prod_{j=1,j\neq \rm something}^{n+1}a_jk+b_j$

$\ \ \ \ \ \ \,$如何就卡住了,我们不知道怎么办,但是还记得吗,我们只能得到近似图像,所以我们只需要满足$S_i(x)$的图像大约在$[x=x_i]$就行了,拉格朗日给了一种解法:

$S_i(k)=\prod_{j=1,j\neq i}^{n+1}\frac{k-x_j}{x_i-x_j}$

$\ \ \ \ \ \ \ \,$具体证明可以看这里:【 VictoryCzt_拉格朗日插值法学习笔记】

$\ \ \ \ \ \ \ \,$所以得到拉格朗日公式($n$个点),复杂度是$O(n^2)$的:

$f(k)=\sum_{i=1}^{n} y_i\prod_{j=1,j\neq i}^{n}\frac{k-x_j}{x_i-x_j}$

代码实现:

$\ \ \ \ \ \ \,$普通的拉格朗日差值法其实代码很简单了:

1
2
3
4
5
6
7
8
9
10
11
double Lagrange(double *x,double *y,double n,double k){
double top,bot,ret=0;
for(double i=1;i<=n;i++){
top=1.0;bot=1.0;
for(double j=1;j<=n;++j)if(i!=j)
top*=k-x[j],
bot*=x[i]-x[j];
ret+=y[i]*top/bot;
}
return ret;
}

$\ \ \ \ \ \ \,$有些时候呢,我们的 $x_i$ 是连续的,所以说公式变形为:

$f(k)=\sum_{i=1}^{n} y_i\prod_{j=1,j\neq i}^{n}\frac{k-j}{i-j}$

$\ \ \ \ \ \ \,$我们令:

$pre_i=\prod_{j=1}^{i}(k-j)=pre_{i-1}\times(k-j)$

$suf_i=\prod_{j=i}^{n}(k-j)=suf_{i+1}\times(k-j)$

$\ \ \ \ \ \ \,$显然对于这两个函数是可以$O(n)$预处理的。

$\ \ \ \ \ \ \,$原公式可以化成:

$f(k)=\sum_{i=1}^{n} y_i \frac{pre_{i-1}\times suf_{i+1}}{(i-1)!\times(n-i)!}$

$\ \ \ \ \ \ \,$注意分母是有符号需要判断的,$n-i$为奇数时,分母为负。

$\ \ \ \ \ \ \,$显然阶乘也可以预处理,于是乎复杂度变成了$O(n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
long long pre[N],suf[N],fac[N];
long long Lagrange(long long *y,long long n,long long k){
long long top,bot,ret=0;
pre[0]=suf[n+1]=fac[0]=1;
for(int i=1;i<=n;i++)pre[i]=pre[i-1]*(k-i);
for(int i=n;i>=1;i--)suf[i]=suf[i+1]*(k-i);
for(int i=2;i<=n;i++)fac=fac*i;
for(int i=1;i<=n;i++){
top*=pre[i-1]*suf[i+1],
bot*=fac[i-1]*fac[n-i];
if((n-i)&1) ret-=y[i]*top/bot;
else ret+=y[i]*top/bot;
}
return ret;
}

例题【Codeforces Round #492 F. Cowmpany Cowmpensation】

$\ \ \ \ \ \ \,$题目大意:

$\ \ \ \ \ \ \,$给你一棵 $n$ 个节点的树,以 $1$ 为根节点,现在让你给每个节点分配一个权值$∈[1,D]$,使得每个节点的权值不超过他的父亲节点( $1$ 号节点除外),问一共有多少种分配方式。$(1≤n≤3000, 1≤D≤10^9)$。

$\ \ \ \ \ \ \,$显然直观的有一个二维树形dp的做法,一维表示当前节点,一维表示这个节点的取值情况,来描述方案数:

$f_{i,j}=\prod_{s\in Son_i}f_{s,j}+f_{i,j-1}$

$\ \ \ \ \ \ \,$答案就是$f_{1,D}$了。

$\ \ \ \ \ \ \,$复杂度为 $O(nD)$ 的,过不了,但是我们观察dp式子,可以发现……这个家伙是个多项式吧?继续大胆猜测,这个是个 $n$ 维多项式。

$\ \ \ \ \ \ \,$所以我们用dp来预处理出 $f_{1,i}\ ,\ {i\in[1,n+1]}$,然后使用拉格朗日差值法,便可以求出 $f_{1,D}$ 了,复杂度是 $O(n^2)$ 的。

$\ \ \ \ \ \ \,$代码如下,取模真的取死人啊

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cstdio>
#include<vector>
#include<string>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
using namespace std;
const int inf=0x7fffffff;
const double eps=1e-10;
const double pi=acos(-1.0);
//char buf[1<<15],*S=buf,*T=buf;
//char getch(){return S==T&&(T=(S=buf)+fread(buf,1,1<<15,stdin),S==T)?0:*S++;}
inline int read(){
int x=0,f=1;char ch;ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=0;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch&15);ch=getchar();}
if(f)return x;else return -x;
}
const int N=3030;
const long long mod=1e9+7;
long long power(long long a,long long b){
long long ans=1ll;
for(;b;b>>=1,a=a*a%mod)if(b&1ll)ans=ans*a%mod;
return ans;
}
int n;
long long D;
vector<int> G[N];
long long f[N][N];
void dfs(int u){
for(int i=1;i<=n+1;i++)f[u][i]=1;
for(int i=0;i<G[u].size();i++){
dfs(G[u][i]);
for(int j=1;j<=n+1;j++)
f[u][j]=(f[u][j]*f[G[u][i]][j])%mod;
}
for(int x=1;x<=n+1;x++)f[u][x]=(f[u][x]+f[u][x-1])%mod;
}
long long pre[N],suf[N],fac[N];
long long Lagrange(long long *y,long long n,long long k){
long long top,bot,ret=0,Fac;
pre[0]=suf[n+1]=Fac=1ll;
for(long long i=1;i<=n;i++)pre[i]=(pre[i-1]*(k-i+mod)%mod)%mod;
for(long long i=n;i>=1;i--)suf[i]=(suf[i+1]*(k-i+mod)%mod)%mod;
for(long long i=2;i<=n;i++)Fac=Fac*i%mod;
fac[n]=power(Fac,mod-2);
for(int i=n-1;i>=0;i--)fac[i]=fac[i+1]*1ll*(i+1)%mod;
for(int i=1;i<=n;i++){
top=(pre[i-1]*suf[i+1]%mod)%mod,
bot=(fac[i-1]*fac[n-i]%mod)%mod;
if((n-i)&1)ret=(ret-(y[i]*top%mod*bot%mod)+mod)%mod;
else ret=(ret+(y[i]*top%mod*bot%mod))%mod;
}
return (ret+mod)%mod;
}
int main()
{
n=read();scanf("%I64d",&D);D%=mod;
for(int i=2,a;i<=n;i++)
a=read(),G[a].push_back(i);
dfs(1);
if(D<=1ll*n+1ll)printf("%I64d\n",f[1][D]);
else printf("%I64d\n",Lagrange(f[1],n+1,D));
return 0;
}