【Python】形式的冪級数(FPS)(一部未完成)

【注意】このライブラリは一部の機能が未完成です。また、expの計算が遅く改善の余地があります。Pythonによる完成度の高い形式的冪級数ライブラリを探されている方は、yosupo judgeのSubmission一覧画面で①形式的冪級数(Formal Power Series)関連の問題の、②言語はPythonで、③速度がそこそこ速いACコードでフィルターをかけて絞って探すのがおすすめです。

機能

形式的冪級数の逆元、exp、log等の計算、スカラー演算、他の形式的冪級数との四則演算(+、-、*、/)ができます。

使用例(mod=998244353)

f = FPS([0,1,2,3,4])
g = FPS([3,5,2,2,1])

# Basic arithmetic between two series
print("f+g = ", f+g)
print("f-g = ", f-g)
print("f*g = ", f*g)
print("f/g = ", f/g)
print("g^5 = ", g**5)

# f+g =  FPS([3, 6, 4, 5, 5], d=5)
# f-g =  FPS([998244350, 998244349, 0, 1, 3], d=5)
# f*g =  FPS([0, 3, 11, 21, 33], d=5)
# f/g =  FPS([0, 1, 2, 3, 4], d=5)
# g^5 =  FPS([243, 2025, 7560, 17460, 29760], d=5)




# Scalar arithmetic (per the outputs below, + and - apply the scalar to every coefficient)
print("f+5=", f+5)
print("f-5=", f-5)
print("f*5=", f*5)
#print("f/5=", f/5)

# f+5= FPS([5, 6, 7, 8, 9], d=5)
# f-5= FPS([998244348, 998244349, 998244350, 998244351, 998244352], d=5)
# f*5= FPS([0, 5, 10, 15, 20], d=5)




# Multiplicative inverse
print("inv(g)", g.inv())

# log
print("log(g)", g.log())

# exp
print("exp(f)", f.exp())

# Cumulative (prefix) sum, compared with division by (1 - x)
print("cumsum(f)", f.cumsum(), f/FPS([1,-1]))

# Difference of adjacent coefficients, compared with multiplication by (1 - x)
print("diff(f)", f.diff(), f*FPS([1,-1]))

# inv(g) FPS([332748118, 776412274, 147888053, 345072121, 694252247], d=5)
# log(g) FPS([0, 665496237, 610038215, 616200219, 828789292], d=5)
# exp(f) FPS([1, 1, 499122179, 166374064, 291154613], d=5)
# cumsum(f) FPS([0, 1, 3, 6, 10], d=5) FPS([0, 1, 3, 6, 10], d=5)
# diff(f) FPS([1, 1, 1, 1], d=4) FPS([0, 1, 1, 1, 1], d=5)

Verify

  • 逆元

   - https://judge.yosupo.jp/submission/61635

  • log

   - https://judge.yosupo.jp/submission/61634

参考になった記事やコードなど

ソースコード

mod, g, ig = 998244353, 3, 332748118
# W[i] = g^((mod - 1) >> i) is the twiddle base for a transform stage of size
# 2**i; iW[i] is the same power of ig (the inverse of g modulo ``mod``) and is
# used by the inverse transform.  mod - 1 = 119 * 2**23, hence 24 entries.
W = [pow(g, (mod - 1) >> i, mod) for i in range(24)]
iW = [pow(ig, (mod - 1) >> i, mod) for i in range(24)]


class FPS:
    """Truncated formal power series with coefficients modulo ``mod``.

    ``arr[i]`` holds the coefficient of x**i and ``d`` is the number of
    stored coefficients, fixed at construction time.  Binary operations
    between two series truncate as follows: ``+`` and ``-`` keep
    ``min(d, other.d)`` terms, ``*`` and ``/`` keep ``max(d, other.d)`` terms.
    Adding or subtracting an ``int`` applies it to *every* coefficient.
    """

    def __init__(self, arr=None):
        """Wrap ``arr`` (stored by reference, not copied); defaults to ``[0]``."""
        # A fresh list per instance: a shared ``[0]`` default argument would be
        # mutated by in-place operations such as ``reverse(inplace=True)``.
        self.arr = [0] if arr is None else arr
        self.d = len(self.arr)

    def get(self, i):
        """Return the stored coefficient at index ``i``."""
        return self.arr[i]

    __call__ = get

    def __neg__(self):
        """Return the series with every coefficient negated modulo ``mod``."""
        return FPS([-c % mod for c in self.arr])

    def __add__(self, other):
        """Add an int to every coefficient, or add another series term by term."""
        if isinstance(other, int):
            return FPS([(c + other) % mod for c in self.arr])
        # zip stops at the shorter operand, i.e. min(self.d, other.d) terms.
        return FPS([(a + b) % mod for a, b in zip(self.arr, other.arr)])

    def __sub__(self, other):
        """Subtract an int from every coefficient, or subtract another series."""
        if isinstance(other, int):
            return FPS([(c - other) % mod for c in self.arr])
        return FPS([(a - b) % mod for a, b in zip(self.arr, other.arr)])

    def __mul__(self, other):
        """Scale by an int, or multiply by a series (max(d, other.d) terms)."""
        if isinstance(other, int):
            return FPS([c * other % mod for c in self.arr])
        keep = max(self.d, other.d)
        return FPS(self._convolve(self.arr, other.arr)[:keep])

    def __truediv__(self, other):
        """Divide by an int scalar or by a series with invertible constant term.

        For a series divisor the result is ``self * other**-1`` truncated to
        ``max(self.d, other.d)`` terms, so ``(f / g) * g`` reproduces ``f``.

        Raises:
            ZeroDivisionError: if the scalar, or the divisor's constant term,
                is 0 modulo ``mod``.
        """
        if isinstance(other, int):
            if other % mod == 0:
                raise ZeroDivisionError("division by zero modulo mod")
            inv_c = pow(other, mod - 2, mod)
            return FPS([c * inv_c % mod for c in self.arr])
        n = max(self.d, other.d)
        # Pad the divisor to n terms so that its inverse has full precision.
        inv_other = other.modx(n).inv()
        return FPS(self._convolve(self.arr, inv_other.arr)[:n])

    def __pow__(self, k):
        """Return ``self ** k`` by binary exponentiation.

        ``k == 0`` yields ``FPS([1])``.  A negative ``k`` is evaluated as
        ``self.inv() ** (-k)`` and therefore needs an invertible constant term.
        """
        if k < 0:
            return self.inv() ** (-k)
        out = FPS([1])
        base = self.copy()
        while k:
            if k & 1:
                out = out * base
            k >>= 1
            if k:  # skip the squaring that would never be used
                base = base * base
        return out

    def divmod(self, other):
        """Long division returning ``(quotient, remainder)`` as two FPS objects.

        Coefficients are eliminated starting from index 0, dividing by
        ``other.arr[0]``; ``len(self.arr) - len(other.arr) + 1`` quotient terms
        are produced.  The remainder is what is left of the dividend from its
        first non-zero entry onward.  Empty results are returned as ``[0]``.

        NOTE(review): this is polynomial division with remainder only when
        both coefficient lists are ordered highest degree first (e.g. after
        ``reverse()``) -- confirm the intended convention with callers.
        """
        A, B = self.arr[:], other.arr[:]
        steps = max(len(A) - len(B) + 1, 0)
        quot = []
        if steps:
            lead_inv = pow(B[0], mod - 2, mod)
            for i in range(steps):
                q = A[i] * lead_inv % mod
                quot.append(q)
                for j, b in enumerate(B):
                    A[i + j] = (A[i + j] - q * b) % mod
        rem = [0]
        for i, c in enumerate(A):
            if c:
                rem = A[i:]
                break
        return FPS(quot or [0]), FPS(rem)

    def __repr__(self):
        return f"FPS({self.arr}, d={self.d})"

    def inv(self):
        """Return the multiplicative inverse with ``d`` terms (Newton iteration).

        Raises:
            ZeroDivisionError: if the constant term is 0 modulo ``mod``.
        """
        f = self.arr
        if f[0] % mod == 0:
            raise ZeroDivisionError("constant term is not invertible")
        target = len(f)
        res = [pow(f[0], mod - 2, mod)]
        while len(res) < target:
            n = len(res)
            m = 2 * n
            # res <- 2*res - f*res^2 (mod x^m).  Only the first m coefficients
            # of f influence the first m coefficients of the product.
            t = self._convolve(f[:m], self._convolve(res, res))
            res = [(2 * res[i] - t[i]) % mod if i < n else -t[i] % mod
                   for i in range(m)]
        return FPS(res[:target])

    def differentiate(self):
        """Return the formal derivative (``d - 1`` terms)."""
        return FPS([i * self.arr[i] % mod for i in range(1, self.d)])

    def integrate(self, C=0):
        """Return the formal antiderivative with constant term ``C`` (d + 1 terms)."""
        return FPS([C] + [c * pow(i + 1, mod - 2, mod) % mod
                          for i, c in enumerate(self.arr)])

    def log(self):
        """Return the integral of f'/f with zero constant term, ``d`` terms."""
        return (self.differentiate() * self.inv()).integrate().modx(self.d)

    def exp(self):
        """Return exp(self) with ``d`` terms; the constant term must be 0.

        Newton iteration ``g <- g * (1 + f - log g)`` with the working
        precision doubled at each step.
        """
        assert self.arr[0] == 0
        res = FPS([1])
        size = 1
        while size < self.d:
            size = min(2 * size, self.d)
            cur = res.modx(size)  # zero-pad the current approximation
            # ``+ 1`` on an FPS would touch every coefficient, so add the
            # constant series 1 explicitly.
            one = FPS([1] + [0] * (size - 1))
            res = cur * (-cur.log() + self.modx(size) + one)
        return res

    def modx(self, d0):
        """Return a copy truncated or zero-padded to exactly ``d0`` terms."""
        keep = max(0, min(d0, self.d))
        return FPS(self.arr[:keep] + [0] * (d0 - keep))

    def copy(self):
        """Return a new FPS backed by a copy of the coefficient list."""
        return FPS(self.arr[:])

    def tolist(self, copy=False):
        """Return the coefficient list (the internal one unless ``copy``)."""
        if copy:
            return self.arr[:]
        return self.arr

    def cumsum(self):
        """Return the prefix sums of the coefficients (``d`` terms)."""
        res = [self.arr[0]]
        for i in range(1, self.d):
            res.append((res[-1] + self.arr[i]) % mod)
        return FPS(res)

    def diff(self):
        """Return ``arr[i + 1] - arr[i]`` for each adjacent pair (``d - 1`` terms)."""
        return FPS([(self.arr[i + 1] - self.arr[i]) % mod
                    for i in range(self.d - 1)])

    def reverse(self, inplace=False):
        """Reverse the coefficient order; in place (returns None) or as a new FPS."""
        if inplace:
            self.arr.reverse()
            return
        return FPS(self.arr[::-1])

    def _convolve(self, a, b):
        """Return the full product of coefficient lists ``a`` and ``b`` mod ``mod``.

        The result has ``len(a) + len(b) - 1`` entries.  A schoolbook double
        loop is used when either input has fewer than 50 entries; otherwise a
        number-theoretic transform is used.
        """
        n0 = len(a) + len(b) - 1
        if len(a) < 50 or len(b) < 50:
            if len(a) > len(b):
                a, b = b, a
            ret = [0] * n0
            for i, aa in enumerate(a):
                for j, bb in enumerate(b):
                    ret[i + j] = (ret[i + j] + aa * bb) % mod
            return ret

        # Smallest power of two that can hold all n0 output coefficients.
        k = (n0 - 1).bit_length()
        n = 1 << k

        def fft(f):
            # Forward transform; output is left in permuted order, which
            # ``ifft`` below undoes stage by stage.
            for l in range(k, 0, -1):
                d = 1 << (l - 1)
                U = [1]
                for _ in range(d - 1):
                    U.append(U[-1] * W[l] % mod)
                for i in range(1 << (k - l)):
                    for j in range(d):
                        s = i * 2 * d + j
                        t = s + d
                        f[s], f[t] = (f[s] + f[t]) % mod, U[j] * (f[s] - f[t]) % mod

        def ifft(f):
            # Inverse of ``fft`` up to a factor n, which the caller divides out.
            for l in range(1, k + 1):
                d = 1 << (l - 1)
                U = [1]
                for _ in range(d - 1):
                    U.append(U[-1] * iW[l] % mod)
                for i in range(1 << (k - l)):
                    for j in range(d):
                        s = i * 2 * d + j
                        t = s + d
                        f[s], f[t] = (f[s] + f[t] * U[j]) % mod, (f[s] - f[t] * U[j]) % mod

        fa = list(a) + [0] * (n - len(a))
        fb = list(b) + [0] * (n - len(b))
        fft(fa)
        fft(fb)
        prod = [x * y % mod for x, y in zip(fa, fb)]
        ifft(prod)
        invn = pow(n, mod - 2, mod)
        return [prod[i] * invn % mod for i in range(n0)]