【Python】形式的冪級数(FPS)(一部未完成)
【注意】このライブラリは一部の機能が未完成です。また、expの計算が遅く、改善の余地があります。Pythonで書かれた完成度の高い形式的冪級数ライブラリを探している方は、yosupo judgeのSubmission一覧画面で、①形式的冪級数(Formal Power Series)関連の問題、②言語がPython、③結果がAC、という条件で絞り込み、その中から実行速度がそこそこ速いコードを探すのがおすすめです。
機能
形式的冪級数の逆元、exp、log等の計算、スカラー演算、他の形式的冪級数との四則演算(+、-、*、/)ができます。
使用例(mod=998244353)
# Usage example (mod = 998244353).
# NOTE(review): the expected outputs recorded in the comments below come from an
# implementation in which `f + c` / `f - c` changed every coefficient and `f / h`
# was low-order long division with the remainder appended to the quotient. In
# true power-series arithmetic `f + 5` only changes the x**0 term and `f / g`
# equals `f * g.inv()`. Re-run against the current FPS before relying on those
# particular values.
# NOTE(review): `g` below rebinds the library's module-level name `g` (the
# primitive root 3) if this runs in the same module as the library. The library
# reads that name only while building W, so this is harmless there — re-check if
# that changes.
f = FPS([0,1,2,3,4])
g = FPS([3,5,2,2,1])
# Basic arithmetic
print("f+g = ", f+g)
print("f-g = ", f-g)
print("f*g = ", f*g)
print("f/g = ", f/g)
print("g^5 = ", g**5)
# f+g = FPS([3, 6, 4, 5, 5], d=5)
# f-g = FPS([998244350, 998244349, 0, 1, 3], d=5)
# f*g = FPS([0, 3, 11, 21, 33], d=5)
# f/g = FPS([0, 1, 2, 3, 4], d=5)
# g^5 = FPS([243, 2025, 7560, 17460, 29760], d=5)
# Scalar operations
print("f+5=", f+5)
print("f-5=", f-5)
print("f*5=", f*5)
# Disabled: presumably because `/` did not accept an int divisor (it reads
# `other.arr`) in the implementation the outputs below were produced with.
#print("f/5=", f/5)
# f+5= FPS([5, 6, 7, 8, 9], d=5)
# f-5= FPS([998244348, 998244349, 998244350, 998244351, 998244352], d=5)
# f*5= FPS([0, 5, 10, 15, 20], d=5)
# Multiplicative inverse
print("inv(g)", g.inv())
# log
print("log(g)", g.log())
# exp
print("exp(f)", f.exp())
# Cumulative sum: also obtainable as division by (1 - x)
print("cumsum(f)", f.cumsum(), f/FPS([1,-1]))
# Adjacent differences: compare with multiplication by (1 - x)
print("diff(f)", f.diff(), f*FPS([1,-1]))
# inv(g) FPS([332748118, 776412274, 147888053, 345072121, 694252247], d=5)
# log(g) FPS([0, 665496237, 610038215, 616200219, 828789292], d=5)
# exp(f) FPS([1, 1, 499122179, 166374064, 291154613], d=5)
# cumsum(f) FPS([0, 1, 3, 6, 10], d=5) FPS([0, 1, 3, 6, 10], d=5)
# diff(f) FPS([1, 1, 1, 1], d=4) FPS([0, 1, 1, 1, 1], d=5)
Verify
- 逆元
- https://judge.yosupo.jp/submission/61635
- log
- https://judge.yosupo.jp/submission/61634
参考になった記事やコードなど
- maspyさんのブログ
- hotmanさんの記事
- Nyaanさんのライブラリと説明
- https://nyaannyaan.github.io/library/fps/formal-power-series.hpp.html
- Kiriさんのコード(NTT部分)
- https://atcoder.jp/contests/arc115/submissions/21143410
ソースコード
# NTT-friendly prime mod = 119 * 2**23 + 1, primitive root g = 3, and
# ig = g^-1 (mod mod).
mod, g, ig = 998244353, 3, 332748118
# W[l] = g^((mod - 1) >> l) is a primitive 2**l-th root of unity and iW[l] is
# its inverse; only l <= 23 is tabulated, which bounds the NTT size at 2**23.
W = [pow(g, (mod - 1) >> i, mod) for i in range(24)]
iW = [pow(ig, (mod - 1) >> i, mod) for i in range(24)]


class FPS:
    """Formal power series truncated to finitely many coefficients modulo ``mod``.

    ``arr[i]`` is the coefficient of x**i, always reduced into ``range(mod)``,
    and ``d == len(arr)``.

    Arithmetic conventions:
      * series ``+``, ``-``, ``*``: the shorter operand is zero-padded and the
        result keeps ``max(d1, d2)`` coefficients;
      * an ``int`` operand is the constant series ``c``: ``f + c`` changes only
        the x**0 coefficient, while ``f * c`` scales every coefficient;
      * ``f / h`` is the power-series quotient ``f * h.inv()``.
    """

    def __init__(self, arr=None):
        # ``None`` instead of a mutable ``[0]`` default, plus a private copy, so
        # in-place operations (``reverse(inplace=True)``) can never leak into
        # other instances or into the caller's list.
        if arr is None:
            arr = [0]
        self.arr = [x % mod for x in arr]
        self.d = len(self.arr)

    def get(self, i):
        """Return the stored coefficient of x**i (IndexError if out of range)."""
        return self.arr[i]

    __call__ = get

    def _padded(self, n):
        """Return a new list of the first ``n`` coefficients, zero-padded as needed."""
        n = max(n, 0)  # a negative slice bound would count from the end
        return self.arr[:n] + [0] * (n - self.d)

    def __neg__(self):
        return FPS([-a for a in self.arr])

    def __add__(self, other):
        if isinstance(other, int):
            # Constant series: only the x**0 coefficient changes.
            if not self.d:
                return FPS([other])
            return FPS([self.arr[0] + other] + self.arr[1:])
        n = max(self.d, other.d)
        return FPS([a + b for a, b in zip(self._padded(n), other._padded(n))])

    __radd__ = __add__

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return -self + other

    def __mul__(self, other):
        if isinstance(other, int):
            return FPS([a * other for a in self.arr])
        return FPS(self._convolve(self.arr, other.arr)[:max(self.d, other.d)])

    __rmul__ = __mul__

    def __truediv__(self, other):
        """Return ``self / other`` as a power series.

        An ``int`` divisor multiplies every coefficient by its modular inverse.
        A series divisor gives ``self * other.inv()`` with
        ``max(self.d, other.d)`` coefficients, e.g. ``f / FPS([1, -1])`` is the
        prefix sum of ``f``.

        Raises:
            ZeroDivisionError: if ``other`` (or its constant term) is 0 mod ``mod``.
        """
        if isinstance(other, int):
            if other % mod == 0:
                raise ZeroDivisionError("division by zero modulo mod")
            return self * pow(other, mod - 2, mod)
        n = max(self.d, other.d)
        # Invert to the full output precision, not just other.d terms.
        return self * other.modx(n).inv()

    def __pow__(self, k):
        """Return ``self ** k`` for an integer ``k >= 0`` by binary exponentiation.

        Raises:
            ValueError: if ``k`` is negative.
        """
        if k < 0:
            raise ValueError("exponent must be non-negative")
        out = FPS([1])
        base = self.copy()
        while k:
            if k & 1:
                out = out * base
            k >>= 1
            if k:  # skip the squaring that would never be used
                base = base * base
        return out

    def divmod(self, other):
        """Low-order long division of ``self`` by ``other``.

        Quotient coefficients are produced from x**0 upward as
        ``q_i = A_i / B_0`` (``A`` being the running remainder) for
        ``len(self) - len(other) + 1`` steps. The returned remainder is the
        running remainder with its leading (low-order) zero coefficients
        stripped. An empty part is returned as ``FPS([0])``.

        NOTE(review): this is not Euclidean polynomial division (which works
        from the highest degree down), and stripping the zeros drops the
        remainder's x-offset — confirm this is what callers expect.

        Raises:
            ZeroDivisionError: if a quotient step is needed and ``other`` has
                no non-zero constant term.
        """
        A, B = self.arr[:], other.arr
        steps = max(len(A) - len(B) + 1, 0)
        Q = []
        if steps:
            if not B or B[0] == 0:
                raise ZeroDivisionError("divisor has a zero constant term")
            inv_b0 = pow(B[0], mod - 2, mod)  # loop-invariant
            for i in range(steps):
                q = A[i] * inv_b0 % mod
                Q.append(q)
                for j, b in enumerate(B):
                    A[i + j] = (A[i + j] - q * b) % mod
        R = next((A[i:] for i, a in enumerate(A) if a), [])
        return FPS(Q or [0]), FPS(R or [0])

    def __repr__(self):
        return f"FPS({self.arr}, d={self.d})"

    def inv(self):
        """Return the multiplicative inverse truncated to ``d`` coefficients.

        Newton iteration ``h <- 2h - f*h**2`` doubles the number of correct
        coefficients each round.

        Raises:
            ZeroDivisionError: if the constant term is 0 (mod ``mod``).
        """
        f = self.arr
        if not f:
            return FPS([])
        if f[0] == 0:
            raise ZeroDivisionError("constant term is 0; series is not invertible")
        h = [pow(f[0], mod - 2, mod)]
        while len(h) < len(f):
            n = len(h)
            m = 2 * n
            # Only the first m coefficients of f can influence h mod x**m.
            fhh = self._convolve(f[:m], self._convolve(h, h))[:m]
            h = [((2 * h[i] if i < n else 0) - fhh[i]) % mod for i in range(m)]
        return FPS(h[:len(f)])

    def differentiate(self):
        """Return the formal derivative (``d - 1`` coefficients)."""
        return FPS([a * i for i, a in enumerate(self.arr[1:], 1)])

    @staticmethod
    def _inverses(n):
        """Return ``[0, 1/1, 1/2, ..., 1/(n-1)]`` mod ``mod`` (index 0 is a placeholder)."""
        inv = [0, 1]
        for i in range(2, n):
            # Standard identity: 1/i = -(mod // i) * 1/(mod % i)  (mod mod).
            inv.append(-(mod // i) * inv[mod % i] % mod)
        return inv[:n]

    def integrate(self, C=0):
        """Return the formal antiderivative with constant term ``C`` (``d + 1`` coefficients)."""
        inv = self._inverses(self.d + 1)
        return FPS([C] + [a * inv[i + 1] for i, a in enumerate(self.arr)])

    def log(self):
        """Return the integral of f'/f truncated to ``d`` coefficients (constant term 0).

        For ``f[0] == 1`` this is log(f); in general it equals log(f / f[0]).

        Raises:
            ZeroDivisionError: if the constant term is 0 (via ``inv``).
        """
        return (self.differentiate() * self.inv()).integrate().modx(self.d)

    def exp(self):
        """Return exp(f) truncated to ``d`` coefficients.

        Newton iteration ``h <- h * (1 + f - log h)``; each round doubles the
        number of correct coefficients.

        Raises:
            ValueError: if the constant term is not 0.
        """
        if not self.d:
            return FPS([])
        if self.arr[0] != 0:
            raise ValueError("exp requires a zero constant term")
        one_plus_f = self + 1  # loop-invariant part of the Newton step
        res = FPS([1]).modx(self.d)
        for i in range(1, self.d.bit_length() + 1):
            res = (res * (one_plus_f - res.log())).modx(1 << (i + 1))
        return res.modx(self.d)

    def modx(self, d0):
        """Return the series truncated to (or zero-padded up to) ``d0`` coefficients."""
        return FPS(self._padded(d0))

    def copy(self):
        return FPS(self.arr)

    def tolist(self, copy=False):
        """Return the coefficient list; a fresh copy when ``copy`` is true."""
        if copy:
            return self.arr[:]
        return self.arr

    def cumsum(self):
        """Return the prefix sums of the coefficients (same as ``self / (1 - x)``)."""
        out, s = [], 0
        for a in self.arr:
            s = (s + a) % mod
            out.append(s)
        return FPS(out)

    def diff(self):
        """Return adjacent differences ``arr[i+1] - arr[i]`` (``d - 1`` coefficients)."""
        return FPS([b - a for a, b in zip(self.arr, self.arr[1:])])

    def reverse(self, inplace=False):
        """Reverse the coefficient order; in place (returning None) or as a new FPS."""
        if inplace:
            self.arr.reverse()
            return None
        return FPS(self.arr[::-1])

    def _convolve(self, a, b):
        """Return the linear convolution of coefficient lists ``a`` and ``b`` mod ``mod``.

        If either input is shorter than 50 the schoolbook product is used;
        otherwise an NTT of size 2**k >= len(a) + len(b) - 1 (k <= 23, the
        range covered by W/iW).
        """

        def ntt(f, k):
            # Decimation-in-frequency butterflies: natural order in,
            # bit-reversed order out.
            for l in range(k, 0, -1):
                half = 1 << (l - 1)
                roots = [1]
                for _ in range(half - 1):
                    roots.append(roots[-1] * W[l] % mod)
                for start in range(0, 1 << k, 2 * half):
                    for j in range(half):
                        s = start + j
                        t = s + half
                        f[s], f[t] = (f[s] + f[t]) % mod, roots[j] * (f[s] - f[t]) % mod

        def intt(f, k):
            # Decimation-in-time butterflies with inverse roots: bit-reversed
            # order in, natural order out (still to be scaled by 1 / 2**k).
            for l in range(1, k + 1):
                half = 1 << (l - 1)
                roots = [1]
                for _ in range(half - 1):
                    roots.append(roots[-1] * iW[l] % mod)
                for start in range(0, 1 << k, 2 * half):
                    for j in range(half):
                        s = start + j
                        t = s + half
                        x = f[t] * roots[j] % mod
                        f[s], f[t] = (f[s] + x) % mod, (f[s] - x) % mod

        n0 = len(a) + len(b) - 1
        if len(a) < 50 or len(b) < 50:
            # Schoolbook product; [0] * n0 is [] for empty inputs (n0 <= 0).
            ret = [0] * n0
            if len(a) > len(b):
                a, b = b, a
            for i, x in enumerate(a):
                for j, y in enumerate(b):
                    ret[i + j] = (ret[i + j] + x * y) % mod
            return ret
        k = (n0 - 1).bit_length()  # smallest k with 2**k >= n0
        n = 1 << k
        fa = list(a) + [0] * (n - len(a))
        fb = list(b) + [0] * (n - len(b))
        ntt(fa, k)
        ntt(fb, k)
        prod = [x * y % mod for x, y in zip(fa, fb)]
        intt(prod, k)
        inv_n = pow(n, mod - 2, mod)
        return [x * inv_n % mod for x in prod[:n0]]