Skip to content

Commit

Permalink
优化内存占用
Browse files Browse the repository at this point in the history
  • Loading branch information
EndlessCheng committed Apr 1, 2022
1 parent 2045486 commit 64ffb19
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions copypasta/math_ntt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ https://oi-wiki.org/math/poly/ntt/
1004535809 = 479*2^21+1, g = 3, invG = 334845270
998244353 = 119*2^23+1, g = 3, invG = 332748118
167772161 = 5*2^25+1, g = 3, invG = 55924054
P-1 包含大量因子 2,便于分治
模数任意的解决方案 http://blog.miskcoo.com/2015/04/polynomial-multiplication-and-fast-fourier-transform
任意模数 NTT https://www.luogu.com.cn/problem/P4245
Expand Down Expand Up @@ -90,13 +91,6 @@ https://codeforces.com/problemset/problem/958/F3
todo https://codeforces.com/contest/438/problem/E
*/

type ntt struct {
n int
invN int64
omega []int64
omegaInv []int64
}

const P = 998244353

func _pow(x int64, n int) (res int64) {
Expand All @@ -110,18 +104,24 @@ func _pow(x int64, n int) (res int64) {
return
}

func newNTT(n int) *ntt {
var omega, omegaInv [31]int64 // 多开一点空间

func init() {
const g, invG = 3, 332748118
omega := make([]int64, n+1)
omegaInv := make([]int64, n+1)
for i := 1; i <= n; i <<= 1 {
omega[i] = _pow(g, (P-1)/i)
omegaInv[i] = _pow(invG, (P-1)/i)
for i := 1; i < len(omega); i++ {
omega[i] = _pow(g, (P-1)/(1<<i))
omegaInv[i] = _pow(invG, (P-1)/(1<<i))
}
return &ntt{n, _pow(int64(n), P-2), omega, omegaInv}
}

func (t *ntt) transform(a, omega []int64) {
type ntt struct {
n int
invN int64
}

func newNTT(n int) ntt { return ntt{n, _pow(int64(n), P-2)} }

func (t ntt) transform(a, omega []int64) {
for i, j := 0, 0; i < t.n; i++ {
if i > j {
a[i], a[j] = a[j], a[i]
Expand All @@ -133,9 +133,10 @@ func (t *ntt) transform(a, omega []int64) {
}
}
}
for l := 2; l <= t.n; l <<= 1 {
for l, li := 2, 1; l <= t.n; l <<= 1 {
m := l >> 1
wn := omega[l]
wn := omega[li]
li++
for st := 0; st < t.n; st += l {
b := a[st:]
for i, w := 0, int64(1); i < m; i++ {
Expand All @@ -148,12 +149,12 @@ func (t *ntt) transform(a, omega []int64) {
}
}

func (t *ntt) dft(a []int64) {
t.transform(a, t.omega)
func (t ntt) dft(a []int64) {
t.transform(a, omega[:])
}

func (t *ntt) idft(a []int64) {
t.transform(a, t.omegaInv)
func (t ntt) idft(a []int64) {
t.transform(a, omegaInv[:])
for i, v := range a {
a[i] = v * t.invN % P
}
Expand Down

0 comments on commit 64ffb19

Please sign in to comment.