해당 글에서는 임의의 소수 p와 p미만의 자연수 n에 대해 n!(mod p)를 빠르게 구하는 방법에 대해 설명하고자 한다. 기본적인 풀이는 반복문으로 O(n)으로 푸는 것이다. 여기서 실제 실행시간에서 n에 곱해지는 상수를 줄일 수도 있고 fft를 이용하여 시간복잡도 O(sqrt(n)log^2(n)) 또는 O(sqrt(n)log(n))으로 구할 수도 있다. 이 글에서는 n에 곱해지는 상수를 줄이는 방법에 대해 설명할 것이다. 하지만 이 글에서 설명하는 방법은 fft를 이용하는 방법에도 적용이 가능하다.
상수를 줄이는 방법은 윌슨의 정리, barrett reduction, SIMD, 르장드르의 정리로 4가지가 있다. 4가지 방법은 모두 동시에 적용시킬 수 있다.
1. 윌슨의 정리
윌슨의 정리에 의해 (p-1)!≡-1(mod p)임을 이용해 n>p/2인 경우 n!≡(p-1)!*((n+1)*(n+2)*…*(p-1))^(-1)(mod p)≡-1*((n+1)*(n+2)*…*(p-1))^(-1)(mod p)≡(-1)^(p-n)*((p-n-1)*(p-n-2)*…*1)^(-1)(mod p)≡(-1)^(p-n)*((p-n-1)!)^(-1)(mod p)이므로 (p-n-1)!(mod p)을 계산해 모듈러 역원으로 n!(mod p)를 계산할 수 있다. n>p/2이면 p-n-1<p-p/2-1<=p/2이므로 최악의 경우에서의 실행시간을 반으로 줄일 수 있다. 또한 n이 p/2에 가까울 때가 최악의 경우가 된다.
2. barrett reduction
barrett reduction은 m이 고정되어 있고 어떤 수 a를 m으로 나눈 나머지를 여러 번 구해야 할 때 곱셈과 비트 시프트 연산을 이용해 빠르게 구하는 방법이다. 먼저 a와 m보다 큰 2의 거듭제곱 2^n에 대해 2^n을 m으로 나눈 몫과 나머지를 q, r이라 하자. 이 때 0<=r<m이고 q=(2^n-r)/m이다. 여기서 floor(a*q/2^n)이 floor(a/m)의 근삿값이다. 왜냐하면 a*q/2^n=(a/m)*(2^n-r)/2^n=a/m-a*r/(m*2^n)이고 a<=2^n, r<m이므로 0<=a*r/(m*2^n)<=1, a/m-1<=a*q/2^n<=a/m, floor(a/m)-1<=floor(a*q/2^n)<=floor(a/m)이기 때문이다. 그러므로 a를 m으로 나눈 나머지를 구하기 위해서는 a-floor(a*q/2^n)*m을 구한 다음 그 값이 m이상이면 m을 빼주면 된다.
여기서 2^n을 가능한 모든 a와 m보다 크도록 잡고 q=floor(2^n/m)을 미리 구하고 2^n으로 나눌 때 비트 시프트 연산을 이용하면 기본적인 나눗셈에 비해 속도가 향상된다. 일반적으로 m미만의 두 수를 곱한 값을 m으로 나눈 나머지를 구하는 경우가 많고 이때 두 수의 곱이 m^2미만이므로 2^n을 m^2이상으로 잡으면 충분하다.
또한 연산 과정 중에서 모듈러를 반복하여 적용하는 경우, a-floor(a*q/2^n)*m이 m이상임에 상관없이 m을 빼지 않고 대략적인 나머지를 계산한 후 최종적인 결과를 구할 때 정확한 나머지를 구해주어도 된다. 이때 시간이 어느 정도 단축 될 수 있다. 2^n은 4*m^2이상으로 잡으면 충분하다.
두 수의 곱셈을 m으로 나눈 나머지를 barrett reduction으로 구하는 코드는 다음과 같다:
#include <stdio.h>
#include <stdlib.h>
long long n,n2,q,m;
long long mul(long long b,long long c){//b*c%m을 계산하는 함수
long long a=b*c;
a=a-(a*q>>n)*m;//a-floor(a*q/2^n)*m, 비트 시프트 연산을 이용하여 a*q/2^n을 계산
if(a>=m)a-=m;//대략적인 나머지를 계산하는 경우 생략 가능
return a;//b*c%m
}
int main()
{
long long b,c;
n=40;
n2=(long long)1<<n;//n2=2^n
scanf("%lld %lld %lld",&b,&c,&m);//b<m,c<m,a=b*c<m^2<=2^n
q=n2/m;//q=floor(2^n/m)
printf("%lld",mul(b,c));//mul(b,c)=b*c%m
return 0;
}
3. SIMD
SIMD는 여러 값에 대해 같은 연산을 빠르게 수행하는 방식이다. #pragma GCC optimize("Ofast"), #pragma GCC target("avx,avx2") 등을 코드에 추가하고 정수 배열을 2의 거듭제곱 크기로 정의해서 배열의 크기를 128, 256, 512비트 정도로 만들고 배열의 각 원소들에 동일한 연산을 적용시키는 것을 나열하면 컴파일러가 최적화해 줄 것이다. 예를 들어 SIMD를 이용하여 n! mod m을 계산할 때는 다음과 같이 할 수 있다:
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2")
#include <stdio.h>
#include <stdlib.h>
int main()
{
long long a[8]={1,1,1,1,1,1,1,1},b,n,m,i;
scanf("%lld %lld",&n,&m);
b=1;
for(;n%8;n--){
b=b*n%m;
}
for(i=0;i<n;i+=8){//SIMD
a[0]=a[0]*(i+1)%m;
a[1]=a[1]*(i+2)%m;
a[2]=a[2]*(i+3)%m;
a[3]=a[3]*(i+4)%m;
a[4]=a[4]*(i+5)%m;
a[5]=a[5]*(i+6)%m;
a[6]=a[6]*(i+7)%m;
a[7]=a[7]*(i+8)%m;
}
for(i=0;i<8;i++){
b=b*a[i]%m;
}
printf("%lld",b);
return 0;
}
SIMD에 대한 자세한 설명은 다음 링크를 참고하자:
SIMD in PS - ICPC Seoul Regional 2021 L. Trio | JusticeHui가 PS하는 블로그
4. 르장드르의 정리
1부터 n까지의 수를 곱할 때 2의 배수와 배수가 아닌 것을 별도로 생각하면 2의 배수를 곱한 것은 2^(n/2)*(n/2)!이고 2의 배수가 아닌 것을 곱한 것은 n이하의 홀수의 곱이다. 여기서 (n/2)!에 대해 같은 방법을 다시 적용해 주는 식으로 반복하면 n!=2^(n/2+n/4+n/8+…)*(n이하의 홀수의 곱)*(n/2이하의 홀수의 곱)*(n/4이하의 홀수의 곱)*…이므로 2^(n/2+n/4+n/8+…)을 따로 구해주고 (n/2^k이하의 홀수의 곱)=(n/2^k이하 n/2^(k+1)초과의 홀수의 곱)*(n/2^(k+1)이하의 홀수의 곱)임을 이용해 누적하여 (n/2^k이하 n/2^(k+1)초과의 홀수의 곱)을 곱해 (n/2^k이하의 홀수의 곱)을 구하고, 이를 누적하여 곱해주면 결과적으로 n!을 구할 수 있다. (n/2^k이하 n/2^(k+1)초과의 홀수의 곱)을 구할 때 2씩 건너뛰어 주면 홀수만 곱할 수 있다. 실행시간은 n이하의 홀수만 곱하므로 절반으로 줄어든다.
2뿐만 아니라 다른 소수들에 대해서도 동시에 추가하여 적용할 수 있다. 소수 p를 추가하여 적용하는 경우를 생각해 보면 p^(n/p+n/p^2+n/p^3+…)을 미리 구해주고 n/2, n/4, …에 대해 {(n/2)/p, (n/2)/p^2, ...}, {(n/4)/p, (n/4)/p^2, …}, …을 구해서 정렬한 다음 마찬가지로 (이전 수+1부터 해당 수까지의 곱)을 누적하여 곱해서 (해당 수 이하의 곱)을 구하고, 이를 다시 누적해서 곱해주면 된다. (이전 수+1부터 해당 수까지의 곱)을 구할 때는 적용한 소수들 모두와 서로소인 것만 곱하면 된다. 방법은 mod(소수들의 곱(이를 p’이라 하자))에서 소수들의 배수가 아닌 수들을 따로 저장한 다음, 곱하려는 수들의 구간 양끝이 p’의 배수가 될 때까지 나이브하게 조건문으로 서로소인지 확인하며 곱하며 구간을 줄여주고, 2중 반복문으로 구간의 시작에서 끝까지 p’씩 건너뛰면서 저장한 수들을 더한 값을 누적하여 곱해주면 된다. 2, 3, 5에 대해서만 적용해도 p’이하의 소수들의 배수가 아닌 수가 euler_phi(p’)=8으로 8의 배수이므로 SIMD를 쉽게 적용해 줄 수 있다. 또한 이 경우 시간이 8/30(약0.27)배가 된다.
설명한 방법 중 일부를 이용해 조건이 p<=10^9이고 제한시간이 3초인 boj 17467 N! mod P (2)를 해결할 수 있고, 모든 방법을 이용해 조건이 p<=10^10이고 제한시간이 3초인 boj 17468 N! mod P (3)을 해결할 수 있다. p의 제한과 최적화 정도에 따른 실행시간은 다음과 같다(2번 방법은 별도의 설명이 없으면 대략적인 나마지를 구하는 방법을 사용한 것이고, 시간초과는 실행시간이 3초 이상인 경우이다).
일부 방법 적용시 | 최적화 없이 | 1번 방법 | 1, 2번 방법(2번은 정확한 나머지 계산) | 1, 2번 방법 | 1, 2, 3번 방법 |
0<n<p<1e8 | 1260ms | 628ms | 272ms | 244ms | 60ms |
0<n<p<1e9 | 시간초과 | 시간초과 | 2748ms | 2452ms | 584ms |
0<n<p<1e10 | 시간초과 | 시간초과 | 시간초과 | 시간초과 | 시간초과 |
모든 방법 적용시 | 4번에서 2이하의 소수 적용 | 4번에서 3이하의 소수 적용 | 4번에서 5이하의 소수 적용 | 4번에서 7이하의 소수 적용 | 4번에서 11이하의 소수 적용 | 4번에서 13이하의 소수 적용 | 4번에서 17이하의 소수 적용 |
0<n<p<1e8 | 28ms | 20ms | 16ms | 12ms | 12ms | 16ms | 44ms |
0<n<p<1e9 | 296ms | 200ms | 160ms | 136ms | 128ms | 124ms | 188ms |
0<n<p<1e10 | 시간초과 | 2508ms | 2228ms | 1912ms | 1752ms | 1660ms | 2336ms |
실행시간은 대체로 p의 제한에 비례한다(제한이 1e10인 경우에는 __int128을 사용했기 때문에 시간이 추가로 더 걸렸다). 4번 방법에서 소수 p를 추가했을 때 실행시간이 대략 (p-1)/p배가 된다. 또한 4번 방법에서 p의 제한이 1e8인 경우에는 7까지의 소수를, 1e9또는 1e10인 경우에는 13까지의 소수를 적용시키는 것이 가장 빠르다. 4번 방법에서 소수를 추가로 적용시킬 때 일정 시점 이후로 실행시간이 늘어나는 것은 4번 방법에서 전처리할 때의 시간이 늘어나기 때문인 것 같다.
fft를 이용하는 방법은 두 가지가 있다. 첫 번째 방법은 multipoint evaluation 알고리즘을 이용하여 시간복잡도 O(sqrt(n)log^2(n))(또는 O(sqrt(n)log^(3/2)(n)))에 해결하는 방법이고 두 번째 방법은 라그랑주 보간법을 이용하여 시간복잡도 O(sqrt(n)log(n))에 해결하는 방법이다. 자세한 방법은 다음 링크들을 참고하자:
첫 번째 방법: 다항식 나눗셈과 다중계산
두 방법 모두에 설명한 4가지 방법을 전부 적용할 수 있다. 윌슨의 정리는 팩토리얼을 계산하기 전에 적용시킬 수 있고 barrett reduction은 ntt(정수론적 fft)를 이용하는 경우 fft에도 적용시킬 수 있다. SIMD 또한 fft에 적용시킬 수 있다. 르장드르의 정리는 계산하는 다항식을 일부 소수들의 배수가 곱해지지 않도록 바꾸고(ex)소수가 2, 3인 경우 (6x+1)(6x+5)(6x+7)(6x+11)…) fft로 계산한 뒤 르장드르의 정리를 이용할 때 각 구간을 계산할 때 계산해 놓은 값이 필요하면 이용하면 된다.
'코딩 > 문제' 카테고리의 다른 글
USACO 2024 December Contest Silver 1, 2번 문제 (0) | 2024.12.28 |
---|---|
Boj 13428 배열의 합 (3) | 2024.12.19 |