扩展中国剩余定理

       关于扩展中国剩余定理及扩展中国剩余定理的复习笔记:

中国剩余定理(CRT)

       中国剩余定理是求解如下同余方程组的算法:

{xc1  (mod m1)xc2  (mod m2)xc3  (mod m3)           xcn  (mod mn)
       m都互质时,我们使用中国剩余定理(CRT)。

       对于一个同余方程组,我们从简单的入手:

{xc1  (mod m1)xc2  (mod m2)

       可以写成:{x=c1+k1m1x=c2+k2m2

       联立式子:x=c1+k1m1=c2+k2m2
k1m1k2m2=c2c1
       因为m1m2互质,所以对于任意c2c1的取值,肯定有一队合法解k1k2.

       然而对于求形如ax+by=c的解,就是扩展欧几里得干的事情了:

1
2
3
4
void exgcd(int a,int b,int &d,int &x,int &y){
if(!b){d=a;x=1;y=0;return;}
exgcd(b,a%b,d,y,x);y-=x*(a/b);
}

       它求出的xy,既是ax+by=gcd(a,b)=1的解。

       知道了k1m1+k2(m2)=1的解为k1k2,那么就容易得到k1m1k2m2=c2c1的解了:

k1m1+k2(m2)=1

k1m1k2m2=1

(k1(c2c1))m1(k2(c2c1))m2=c2c1

       k1=k1(c2c1)k2=k2(c2c1)

       现在我们带回去,就可以得到:
x=c1+(k1(c2c1))m1=c2+(k2(c2c2))m2

       至此我们的答案就出来了,如果遇到很多的方程,我们不妨就这样合并下去,就出来了,不过问题来了,中国剩余定理(CRT)只适用于当m都互质时,适用范围比较小,下面我们马上引入扩展中国剩余定理(EXCRT),模板还是记它吧,就不贴中国剩余定理(CRT)的代码了。

扩展中国剩余定理(EXCRT)

       对于一个同余方程组,同样我们从简单的入手:

{xc1  (mod m1)xc2  (mod m2)

       同理联立:

x=c1+k1m1=c2+k2m2
k1m1k2m2=c2c1

       因为m1m2不一定互质,所以不能直接用扩展欧几里得了,当然了,我们可以先把他化成互质的:

k1m1gcd(m1,m2)+k2m2gcd(m1,m2)=c2c1gcd(m1,m2)

       套入扩展欧几里得,得到特解:

k1m1gcd(m1,m2)+k2m2gcd(m1,m2)=1

k1(c2c1)gcd(m1,m2)2m1k2(c2c1)gcd(m1,m2)2m2=c2c1gcd(m1,m2)

k1(c2c1)gcd(m1,m2)m1k2(c2c1)gcd(m1,m2)m2=c2c1

       带回去,就可以得到:
x=c1+k1(c2c1)gcd(m1,m2)m1=c2+k2(c2c1)gcd(m1,m2)m2

       那么这样就很显然了,依次合并下去就好了,答案就出来了,当上面的除法不能整除的时候,就是无解。

快速乘

       这个东西对于CRT很重要,很容易在计算两个数的积的时候就爆了long long,所以我们需要用到类似快速幂的做法,变算边取模:

1
2
3
4
5
long long multi(long long a,long long b,long long p){
a=(a%p+p)%p;b=(b%p+p)%p;long long ans=0;
for(;a;a>>=1,b=(b*2)%p)if(a&1)ans=(ans+b)%p;
return ans;
}

       还有O(1)的:

1
2
3
4
long long mul(long long a,long long b,long long mod){
a%=mod,b%=mod;
return ((a*b-(long long)((long long)((long double)a/mod*b+1e-3)*mod))%mod+mod)%mod;
}

最后给出完整模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
int n;
long long x,y,lcm;
long long m[N],c[N];
long long multi(long long a,long long b,long long p){
a=(a%p+p)%p;b=(b%p+p)%p;long long ans=0;
for(;a;a>>=1,b=(b*2)%p)if(a&1)ans=(ans+b)%p;
return ans;
}
long long exgcd(long long a,long long b,long long &x,long long &y){
if(!b){x=1,y=0;return a;}
long long val=exgcd(b,a%b,x,y);
long long t=x;x=y;y=t-a/b*y;return val;
}
long long excrt(long long*m,long long*c,int n){
for(int i=1;i<n;i++){
long long val=exgcd(m[i],m[i+1],x,y);
lcm=m[i]/val*m[i+1];
m[i+1]=lcm;
// if((c[i+1]-c[i])%val)return -1;
val=multi(x,(c[i+1]-c[i])/val,lcm);
c[i+1]=(multi(m[i],val,lcm)+c[i])%lcm;
}
return (c[n]%m[n]+m[n])%m[n];
}

例题

【P4777 【模板】扩展中国剩余定理(EXCRT)】

ImagineOrz大佬的模板,数据还是挺强的,卡了我很久。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<string>
#include<queue>
#include<map>
#include<set>
#include<stack>
#include<cmath>
#include<cctype>
using namespace std;
const int inf=0x7fffffff;
const double eps=1e-10;
const double pi=acos(-1.0);
//char buf[1<<15],*S=buf,*T=buf;
//char getch(){return S==T&&(T=(S=buf)+fread(buf,1,1<<15,stdin),S==T)?0:*S++;}
inline long long read(){
long long x=0,f=1;char ch;ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=0;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch&15);ch=getchar();}
if(f)return x;else return -x;
}
int n;
long long x,y,lcm;
long long m[100055],c[100055];
long long multi(long long a,long long b,long long p){
a=(a%p+p)%p;b=(b%p+p)%p;long long ans=0;
for(;a;a>>=1,b=(b*2)%p)if(a&1)ans=(ans+b)%p;
return ans;
}
long long exgcd(long long a,long long b,long long &x,long long &y){
if(!b){x=1,y=0;return a;}
long long val=exgcd(b,a%b,x,y);
long long t=x;x=y;y=t-a/b*y;return val;
}
long long excrt(long long*m,long long*c,int n){
for(int i=1;i<n;i++){
long long val=exgcd(m[i],m[i+1],x,y);
lcm=m[i]/val*m[i+1];
m[i+1]=lcm;
val=multi(x,(c[i+1]-c[i])/val,lcm);
c[i+1]=(multi(m[i],val,lcm)+c[i])%lcm;
}
return (c[n]%m[n]+m[n])%m[n];
}
int main(){
n=(int)read();
for(int i=1;i<=n;i++)
m[i]=read(),c[i]=read();
long long ans=excrt(m,c,n);
printf("%lld\n",ans);
return 0;
}

【P4774 [NOI2018]屠龙勇士】

      虽然当时当场就看出来是同余方程组了,不过还是因为快速乘坑了好久,还是做少了,太菜了。考点比较多,有点回忆不起来了,还是贴一下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<string>
#include<set>
#include<map>
#include<cmath>
using namespace std;
const long long inf=0x7fffffff;
const double eps=1e-10;
const double pi=acos(-1.0);
const int N=1e5+10;
inline long long read(){
long long x=0,f=1;char ch;ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
long long n,m,T;
long long f[N],p[N],a[N],w[N],x[N];
long long gcd(long long a,long long b)
{return b?gcd(b,a%b):a;}
long long exgcd(long long a,long long b,long long &x,long long &y){
if(!b){x=1,y=0;return a;}
long long res=exgcd(b,a%b,y,x);
y-=x*(a/b);
return res;
}
long long inv(long long a,long long b){
long long x=0,y=0,g=exgcd(a,b,x,y);
if(g>1)return -1;
return (x+b)%b;
}
long long fast_multi(long long a,long long b,long long p) {
a=(a%p+p)%p;
b=(b%p+p)%p;
long long ans=0;
for(;a;a>>=1,b=(b<<1)%p)
if(a&1LL)ans=(ans+b)%p;
return ans;
}
bool CRT(long long w1,long long p1,long long w2,long long p2,long long &w,long long &p){
long long x,y,z=w2-w1,g=exgcd(p1,p2,x,y);
if(z%g)return 0;
long long t=z/g;
x=fast_multi(x,t,p2/g);
p=p1/g*p2;
w=((w1+fast_multi(x,p1,p))%p+p)%p;
return 1;
}
long long solve(){
for(int i=1;i<=n;i++){
long long g=gcd(a[i],gcd(f[i],p[i]));
f[i]/=g,p[i]/=g,a[i]/=g;
long long Inv=inv(f[i],p[i]);
if(Inv<0)return -1LL;
x[i]=fast_multi(a[i],Inv,p[i]);
}
long long W=x[1],P=p[1];
for(int i=2;i<=n;i++)
if(!CRT(W,P,x[i],p[i],W,P))return -1LL;
for(int i=1;i<=n;i++){
long long val=(a[i]+f[i]-1)/f[i];
if(val<=W)continue;
long long k=(val-W+P-1)/P;
W+=k*P;
}
return W;
}
multiset<long long> S;
int main()
{
// freopen("dragon.in","r",stdin);
// freopen("dragon.out","w",stdout);
T=read();
while(T--){
n=read(),m=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=n;i++)p[i]=read();
for(int i=1;i<=n;i++)w[i]=read();
S.clear();
while(m--)S.insert(read());
for(int i=1;i<=n;i++){
multiset<long long> :: iterator p=S.begin();
if((*p)<a[i])p=--S.upper_bound(a[i]);
f[i]=*p,S.erase(p);
S.insert(w[i]);
}
printf("%lld\n",solve());
}
fclose(stdin);
fclose(stdout);
return 0;
}