Atcoder Beginner Contest (ABC) 221 E - LEQ

問題文

想定解法はBIT(Binary Indexed Tree)でした。最初からBITだと分かって解いたというよりは、手元でノートに色々書き込みながら考察を進めていく中でなんとなく「BIT使えそうだな」→「あ、本当に行けそう」と徐々に確信が強くなり、解くことができました。以下では、本番で私が考察した順番でなるべく丁寧にステップを踏んで説明しています。長々と書いているので、結論だけを見たい方は正しい操作手順以降をご覧ください。

部分列の性質:左端と右端が決まればあとは自由

例があった方がわかりやすいので、以下の数列で計算することを考えてみましょう。

A = [4, 2, 7, 1, 5, 6, 3]

まず気づくのは、ここから部分列A'を選んだとき、問題文より部分列が正しい部分列として成り立つ条件はA_1' <= A_k'なので、部分列の左端と右端が決まればそれ以外は何でもよいということです。

例えば、A'の左端としてA_1 = 4を、右端としてA_5 = 5を選んだ場合を考えてみます。この時、左端と右端に挟まれたA_2 = 2, A_3 = 7, A_4 = 1はいずれも選んでも選ばなくても部分列の条件を満たします。つまり、

A' = [4, 5](何も選ばない場合)
A' = [4, 2, 5](A_2を選んだ場合)
A' = [4, 7, 5](A_3を選んだ場合)
...
A' = [4, 2, 7, 5](A_2, A_3を選んだ場合)
...
A' = [4, 2, 7, 1, 5](全て選んだ場合)

のいずれもが正しい部分列として成り立ちます。左端と右端以外の要素1つ1つについて選ぶor選ばないの選択ができるので、左端の位置をl, 右端の位置をrとすると、あり得る部分列の個数は2r-l-1となります。

計算をまとめたい:右端を固定

次に、計算をまとめることを考えます。

例えば上の例

A = [4, 2, 7, 1, 5, 6, 3]

で、右端をA_5 = 5に固定して考えてみます。この時、左端の位置を1<=l<=4の範囲で移動させてみると、あり得る部分列の個数はそれぞれ

l = 1: 25-1-1 = 8通り
l = 2: 25-2-1 = 4通り
l = 3: 0通り(A_3 = 7 > 5で正しい部分列の条件を満たさないため除外)
l = 4: 25-4-1 = 1通り

となります。右端が左端より大きい場合を除外した上で、左端の位置によって場合の数が変わってくることがわかります。

これを一般化して、右端A_rを固定して考えたとき、うまくまとめて計算するには、①既出の数字(A_i, i < r)のうち、A_r以下のものの、②インデックスの位置i(i<r)を踏まえた場合の数の総和の2点が素早く分かればいいということになります。

転倒数の計算に使うBITとの共通点

①既出の数字(A_i, i < r)のうち、A_r以下のものをまとめて計算するのは、BITを使った転倒数の計算と似ています(ただし、転倒数はA_rより大きいものをまとめて計算する)。BITを使った転倒数の計算の詳細については例えば以下の記事を参照してください。

いかたこのたこつぼ 転倒数

以下では、BITのインスタンスをbitとして、bitの配列の左からi番目の要素にxを加える操作をbit.add(i, x), 左からl番目からr番目までの要素の総和を計算する操作をbit.query(l-1,r)(累積和の計算と同じで、左側が開になる半開区間)と書くことにします。

BITを使って、例えば以下のようなことができそうです。

Aを左の要素から順に見ていくことを考えます。BITの配列は最初、全て0で初期化されています。

BIT: [0, 0, 0, 0, 0, 0, 0]

まず最初にA_1=4の場所(1-indexed)に1を加算します(bit.add(4, 1))。

(A_1 = 4を加算) BIT: [0, 0, 0, 1, 0, 0, 0]

続いて、A_2 = 2です。2の場所に1を加算する前に、BITの配列を1<=i<=2の範囲の合計値をチェックします(bit.query(0, 2))。この場合は全て0(=既出の2以下の要素の個数は0)なので、何も起こりません。続いて、2の場所に1を加算します(bit.add(2, 1))。

(A_2 = 2を加算) BIT: [0, 1, 0, 1, 0, 0, 0]

続いて、A_3 = 7です。7の場所に1を加算する前に、BITの配列を1<=i<=7の範囲の合計値をチェックします(bit.query(0, 7))。この場合は2です。つまり、既出の要素の中で7以下の要素の個数は2ということがわかります。

しかし、これだと単に既出の要素の個数をまとめて計算しているだけで、②インデックスの位置i(i<r)を踏まえた場合の数の総和の計算ができていません。

インデックスの位置の情報を入れたい:BIT配列に加算する際に位置の情報を入れる

上の例では、加算の際に毎回1を加算していました。加算する1の代わりに、要素のインデックスの位置iの情報を入れたらどうなるでしょうか。

再び、Aを左の要素から順に見ていくことを考えます。BITの配列は最初、全て0で初期化されています。

BIT: [0, 0, 0, 0, 0, 0, 0]

まず最初にA_1=4の場所(1-indexed)にA_1のインデックスである1を加算します(bit.add(4, 1))。これは上の例と同じです。

BIT: [0, 0, 0, 1, 0, 0, 0]

続いて、A_2 = 2です。まず加算する前に、BITの配列を1<=i<=2の範囲の合計値をチェックします(bit.query(0, 2))。この場合は全て0(=既出の2以下の要素の個数は0)なので、何も起こりません。これも、上の例と同じです。続いて、2の場所にA_2のインデックスの2を加算します(bit.add(2, 2))。

BIT: [0, 2, 0, 1, 0, 0, 0]

続いて、A_3 = 7です。まず加算する前に、BITの配列を1<=i<=7の範囲の合計値をチェックします(bit.query(0, 7))。この場合は2+1=3が計算結果として出てきます。

ここで、A_1 = 4については1、A_2=2については2がBIT配列に代入されており、近い要素ほど値が大きくなってしまっていることに気づくと思います。一方、今やりたいことは最初に見た

l = 1: 25-1-1 = 8通り
l = 2: 25-2-1 = 4通り
l = 3: 0通り(A_3 = 7 > 5で正しい部分列の条件を満たさないため除外)
l = 4: 25-4-1 = 1通り

のように、近い要素ほど値を小さくしたいのです。

近い要素ほど値を小さくしたい:マイナスにしてみる

上の例で加算するインデックスの位置情報にマイナスをかけることで、やりたいことに一歩近づく気がします。やってみましょう。

BITの配列は最初、全て0で初期化されています。

BIT: [0, 0, 0, 0, 0, 0, 0]

まず最初にA_1=4の場所(1-indexed)にA_1のインデックスである1にマイナスをかけた-1加算します(bit.add(4, -1))。

BIT: [0, 0, 0, -1, 0, 0, 0]

続いて、A_2 = 2です。まず加算する前に、BITの配列を1<=i<=2の範囲の合計値をチェックします(bit.query(0, 2))。この場合は全て0(=既出の2以下の要素の個数は0)なので、何も起こりません。これも、上の例と同じです。続いて、2の場所にA_2のインデックスにマイナスをかけた-2を加算します(bit.add(2, -2))。

BIT: [0, -2, 0, -1, 0, 0, 0]

続いて、A_3 = 7です。まず加算する前に、BITの配列を1<=i<=7の範囲の合計値をチェックします(bit.query(0, 7))。この場合は-2+(-1)=-3が計算結果として出てきます。

確かに、近い要素ほど値が小さくなってはいます。しかし、求めたいものは2r-l-1で、2冪の形になっています。

2冪の形にしてみる

上の例からさらに1歩進んで、2x(xはマイナスをかけたインデックスの値)を加算することを考えてみます。

BITの配列は最初、全て0で初期化されています。

BIT: [0, 0, 0, 0, 0, 0, 0]

まず最初にA_1=4の場所(1-indexed)に2-1を加算します(bit.add(4, pow(2, -1)))。

BIT: [0, 0, 0, 2-1, 0, 0, 0]

続いて、(合計値のチェックは0なので省略)A_2=2の場所に2-2を加算します(bit.add(4, pow(2, -2)))。

BIT: [0, 2-2, 0, 2-1, 0, 0, 0]

続いて、A_3 = 7です。まず加算する前に、BITの配列を1<=i<=7の範囲の合計値をチェックします(bit.query(0, 7))。この場合は2-2 + 2-1が計算結果として出てきます。

この結果は正しいでしょうか?

本来であれば、右端をA_3 = 7で固定した場合、ありうる部分列の個数は、左端lを1<=l<=2の範囲で動かしてみると、

l=1: 23-1-1 = 2通り
l=2: 22-1-1 = 1通り

とならないといけません。

毎回、計算結果に2冪をかけて値を調整する

よく観察すると、以下のように、求めたい結果は、BITでの計算結果に同じ2冪の値をかけたものと等しいことがわかります。

  • l=1
    • BITの計算結果:2-2
    • 求めたい結果:21(BITの計算結果に23をかけたもの)
  • l=2
    • BITの計算結果:2-1
    • 求めたい結果:22(BITの計算結果に23をかけたもの)

つまり、BITに入力された値は2の「インデックス(絶対位置)にマイナスをかけた値」乗でしたが、求めたい値は左端と右端の位置関係(相対位置)によって決まるので、毎回右端の絶対位置に対応した2冪の値をかけて相対位置に対応した値に変換してあげる必要があるのです。

mod逆元の利用

注意点として、この問題では計算結果のmodをとった値が求められているので、上の計算結果のうち2の負冪(例えば2-2)についてはmod逆元を使った形で値を持っておく必要があります。

mod逆元の計算方法については例えば記事を参照してください。

正しい操作手順

以上を踏まえて、正しい操作手順は以下になります。まずAは以下です。

A = [4, 2, 7, 1, 5, 6, 3]

BITの配列は最初、全て0で初期化されています。答えとしてans = 0に初期化します。

BIT: [0, 0, 0, 0, 0, 0, 0]

まず最初にA_1=4の場所(1-indexed)に2-1 (mod p)を加算します(bit.add(4, pow(2, -1, p)))。

BIT: [0, 0, 0, 2-1, 0, 0, 0]

続いて、A_2 = 2です。まず加算する前に、BITの配列を1<=i<=2の範囲の合計値をチェックします(bit.query(0, 2))。この場合は全て0(=既出の2以下の要素の個数は0)なので、何も起こりません。続いて、BIT配列のA_2=2の場所に2-2 (mod p)を加算します(bit.add(2, pow(2, -2, p)))。

BIT: [0, 2-2, 0, 2-1, 0, 0, 0]

続いて、A_3 = 7です。まず加算する前に、BITの配列を1<=i<=7の範囲の合計値をチェックします(bit.query(0, 7))。計算結果として2-2+2-1が返ってきますが、これに現在位置3に対応する23をかけて調整し、23(2-2+2-1) = 21 + 22を得ます。これをansに加算します。続いて、BIT配列のA_3=7の場所に2-2 (mod p)を加算します(bit.add(7, pow(2, -3, p)))。

BIT: [0, 2-2, 0, 2-1, 0, 0, 2-3]

続いて、A_4 = 1です。まず加算する前に、BITの配列を1<=i<=1の範囲の合計値をチェックします(bit.query(0, 1))。1は一番小さい値なので計算結果として0が返ってきますが、24をかけても0なので、何も起こりません。続いて、BIT配列のA_4=1の場所に2-4 (mod p)を加算します(bit.add(1, pow(2, -4, p)))。

BIT: [2-4, 2-2, 0, 2-1, 0, 0, 2-3]

続いて、A_5 = 5です。まず加算する前に、BITの配列を1<=i<=5の範囲の合計値をチェックします(bit.query(0, 5))。計算結果として2-4+2-2+2-1が返ってきますが、これに現在位置5に対応する25をかけて調整し、25(2-4+2-2+2-1) = 21 + 23 + 24を得ます。これをansに加算します。続いて、BIT配列のA_5=5の場所に2-5 (mod p)を加算します(bit.add(5, pow(2, -5, p)))。

BIT: [2-4, 2-2, 0, 2-1, 2-5, 0, 2-3]

この調子で最後まで計算します。

Aに同じ数字が複数含まれる場合も問題なし

なお、最終的な答えansは単にBITの計算結果(=要素の総和)のさらに総和をとっているだけなので、例えばAに同じ数字が2つ以上含まれていて、BITの配列の同じ箇所に複数の値の合計値が格納されている場合、例えば

A = [4, 4, 7, 1, 5, 6, 3]

で、2回目の操作で

BIT: [0, 0, 0, 2-1+2-2, 0, 0, 0]

となった場合でも、正しい計算結果が得られます。

提出コード(PyPy)

BITはyaketakeさんの実装をお借りしています。

Pythonでは組み込み関数powを利用して2-i (mod p) (ただしpは素数)はpow(2, -i, p)として計算できますが、PyPyの場合はpowの第二引数に負数を持ってくることができないので、フェルマーの小定理  a^{p-1} \equiv 1 を利用してpow(2, p-i-1, p)として計算しています。

class BIT:
    '''https://tjkendev.github.io/procon-library/python/range_query/bit.html'''
    def __init__(self, n):
        self.n = n
        self.data = [0]*(n+1)
        self.el = [0]*(n+1)
    def sum(self, i):
        s = 0
        while i > 0:
            s += self.data[i]
            i -= i & -i
        return s
    def add(self, i, x):
        self.el[i] += x
        while i <= self.n:
            self.data[i] += x
            i += i & -i
    def query(self, i, j=None):
        if j is None:
            return self.el[i]
        return self.sum(j) - self.sum(i)
      
N=int(input())
*A,=map(int,input().split())
mod = 998244353

# 座標圧縮
v2i = {v:i+1 for i,v in enumerate(sorted(set(A)))}
B = [v2i[a] for a in A]

bit = BIT(N+1)
ans = 0
for i,b in enumerate(B):
  if i:
    res = bit.query(0,b)
    ans += res*pow(2,i-1,mod)
    ans %= mod
  bit.add(b,pow(2,mod-i-1,mod))
  
print(ans)

【Python】形式的冪級数(FPS)(一部未完成)

【注意】このライブラリは一部の機能が未完成です。また、expの計算が遅く改善の余地があります。Pythonによる完成度の高い形式的冪級数ライブラリを探されている方は、yosupo judgeのSubmission一覧画面で①形式的冪級数(Formal Power Series)関連の問題の、②言語はPythonで、③速度がそこそこ速いACコードでフィルターをかけて絞って探すのがおすすめです。

機能

形式的冪級数の逆元、exp、log等の計算、スカラー演算、他の形式的冪級数との四則演算(+、ー、*、/)ができます。

使用例(mod=998244353)

f = FPS([0,1,2,3,4])
g = FPS([3,5,2,2,1])

# 基本演算
print("f+g = ", f+g)
print("f-g = ", f-g)
print("f*g = ", f*g)
print("f/g = ", f/g)
print("g^5 = ", g**5)

# f+g =  FPS([3, 6, 4, 5, 5], d=5)
# f-g =  FPS([998244350, 998244349, 0, 1, 3], d=5)
# f*g =  FPS([0, 3, 11, 21, 33], d=5)
# f/g =  FPS([0, 1, 2, 3, 4], d=5)
# g^5 =  FPS([243, 2025, 7560, 17460, 29760], d=5)




# スカラー演算
print("f+5=", f+5)
print("f-5=", f-5)
print("f*5=", f*5)
#print("f/5=", f/5)

# f+5= FPS([5, 6, 7, 8, 9], d=5)
# f-5= FPS([998244348, 998244349, 998244350, 998244351, 998244352], d=5)
# f*5= FPS([0, 5, 10, 15, 20], d=5)




# 逆元
print("inv(g)", g.inv())

# log
print("log(g)", g.log())

# exp
print("exp(f)", f.exp())

# 累積和
print("cumsum(f)", f.cumsum(), f/FPS([1,-1]))

# 差分
print("diff(f)", f.diff(), f*FPS([1,-1]))

# inv(g) FPS([332748118, 776412274, 147888053, 345072121, 694252247], d=5)
# log(g) FPS([0, 665496237, 610038215, 616200219, 828789292], d=5)
# exp(f) FPS([1, 1, 499122179, 166374064, 291154613], d=5)
# cumsum(f) FPS([0, 1, 3, 6, 10], d=5) FPS([0, 1, 3, 6, 10], d=5)
# diff(f) FPS([1, 1, 1, 1], d=4) FPS([0, 1, 1, 1, 1], d=5)

Verify

  • 逆元

   - https://judge.yosupo.jp/submission/61635

  • log

   - https://judge.yosupo.jp/submission/61634

参考になった記事やコードなど

ソースコード

mod, g, ig = 998244353, 3, 332748118
W = [pow(g, (mod - 1) >> i, mod) for i in range(24)]
iW = [pow(ig, (mod - 1) >> i, mod) for i in range(24)]

class FPS:
    def __init__(self, arr=[0]):
        self.arr = arr
        self.d = len(arr)
        
    def get(self, i):
        return self.arr[i]
    
    __call__ = get
    
    def __neg__(self):
        return FPS([-self.arr[i]%mod for i in range(self.d)])
        
    def __add__(self, other):
        if isinstance(other, int): return FPS([(self(i)+other)%mod for i in range(self.d)])
        return FPS([(self(i)+other(i))%mod for i in range(min(self.d, other.d))])
    
    def __sub__(self, other):
        if isinstance(other, int): return FPS([(self(i)-other)%mod for i in range(self.d)])
        return FPS([(self(i)-other(i))%mod for i in range(min(self.d, other.d))])
    
    def __mul__(self, other):
        if isinstance(other, int): return FPS([self(i)*other%mod for i in range(self.d)])
        return FPS(self._convolve(self.arr, other.arr)[:max(self.d, other.d)])
    
    def __truediv__(self, other):
            A,B = self.arr[:], other.arr[:]
            Q = []
            for i in range(max(len(A)-len(B)+1, 0)):
                Q.append(A[i]*pow(B[0], mod-2, mod)%mod)
                for j in range(len(B)):
                    A[i+j] -= Q[-1]*B[j]
                    A[i+j] %= mod
            R = []
            for i in range(len(A)):
                if A[i]: 
                    R = A[i:]
                    break
            if not Q: Q = [0]
            if not R: R = [0]
            return FPS(Q+R)
        
    def __pow__(self, k):
        out = FPS([1])
        base = self.copy()
        for i in range(k.bit_length()):
            if (k>>i)&1: out = out * base
            base = base*base
        return out
    
    def divmod(self, other):
        A,B = self.arr[:], other.arr[:]
        Q = []
        for i in range(max(len(A)-len(B)+1, 0)):
            Q.append(A[i]*pow(B[0], mod-2, mod)%mod)
            for j in range(len(B)):
                A[i+j] -= Q[-1]*B[j]
                A[i+j] %= mod
        R = []
        for i in range(len(A)):
            if A[i]: 
                R = A[i:]
                break
        if not Q: Q = [0]
        if not R: R = [0]
        return FPS(Q), FPS(R)
    
    def __repr__(self):
        return f"FPS({self.arr}, d={self.d})"
    
    def inv(self):
        f = self.arr
        g = [pow(f[0], mod-2, mod)]
        while True:
            n = len(g)
            if n >= len(f): return FPS(g[:len(f)])
            a = [g[i]*2%mod if i < n else 0 for i in range(n*2)]
            b = self._convolve(f, self._convolve(g, g))[:n*2]
            g = [(a[i]-b[i]) % mod for i in range(n*2)]
    
    def differentiate(self):
        return FPS([(self.arr[i]*i)%mod for i in range(1, self.d)])
    
    def integrate(self, C=0):
        return FPS([C]+[self.arr[i]//(i+1) if self.arr[i]%(i+1)==0 else (self.arr[i] * pow(i+1, mod-2, mod))%mod for i in range(self.d)])
    
    def log(self):
        return (self.differentiate() * self.inv()).integrate().modx(self.d)
    
    def exp(self):
        assert self.arr[0] == 0
        g = FPS([0 if i else 1 for i in range(self.d)])
        e = g.copy()
        for i in range(1,self.d.bit_length()+1):
            g = g * (-g.log() + self + e)
            g = g.modx(1<<(i+1))
        g = g.modx(self.d)
        return g
    
    def modx(self, d0):
        return FPS([self.arr[i] if i < self.d else 0 for i in range(d0)])
            
    def copy(self):
        return FPS(self.arr[:])
    
    def tolist(self, copy=False):
        if copy: return self.arr[:]
        return self.arr
    
    def cumsum(self):
        res = [self.arr[0]]
        for i in range(1,self.d): res.append((res[-1]+self.arr[i])%mod)
        return FPS(res)
    
    def diff(self):
        return FPS([(self.arr[i+1]-self.arr[i])%mod for i in range(self.d-1)])
    
    def reverse(self, inplace=False):
        if inplace:
            self.arr.reverse()
            return
        return FPS(self.arr[::-1])
    
    def _convolve(self, a, b):
        def fft(f):
            for l in range(k, 0, -1):
                d = 1 << l - 1
                U = [1]
                for i in range(d):
                    U.append(U[-1] * W[l] % mod)
                for i in range(1 << k - l):
                    for j in range(d):
                        s = i * 2 * d + j
                        t = s + d
                        f[s], f[t] = (f[s] + f[t]) % mod, U[j] * (f[s] - f[t]) % mod
        def ifft(f):
            for l in range(1, k + 1):
                d = 1 << l - 1
                U = [1]
                for i in range(d):
                    U.append(U[-1] * iW[l] % mod)
                for i in range(1 << k - l):
                    for j in range(d):
                        s = i * 2 * d + j
                        t = s + d
                        f[s], f[t] = (f[s] + f[t] * U[j]) % mod, (f[s] - f[t] * U[j]) % mod
        n0 = len(a) + len(b) - 1
        if len(a) < 50 or len(b) < 50:
            ret = [0] * n0
            if len(a) > len(b): a, b = b, a
            for i, aa in enumerate(a):
                for j, bb in enumerate(b):
                    ret[i+j] = (ret[i+j] + aa * bb) % mod
            return ret

        k = (n0).bit_length()
        n = 1 << k
        la,lb = len(a),len(b)
        a = [a[i] if i < la else 0 for i in range(n)]
        b = [b[i] if i < lb else 0 for i in range(n)]
        fft(a), fft(b)
        a = [(a[i]*b[i])%mod for i in range(n)]
        ifft(a)
        invn = pow(n, mod - 2, mod)
        a = [a[i]*invn %mod for i in range(n0)]
        return a

【Python】全方位木DP (書きかけ)

【注意】この記事は書きかけです。

抽象化全方位木DPのライブラリです。全3バージョンあります。多数の引数が入り乱れている実装で改善の余地がありますが、問題をACできる程度には動きます。

実装にあたっては、すぬけさんによるABCの解説放送(https://www.youtube.com/watch?v=zG1L4vYuGrg)のほか、下記の記事が大変参考になりました。
- https://qiita.com/Kiri8128/items/a011c90d25911bdb3ed3
- https://ei1333.hateblo.jp/entry/2017/04/10/224413
- https://algo-logic.info/tree-dp/
- https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e

バージョン1:左右累積バージョン

引数に入れる関数の説明

  • f0(a, b): 次数が1のノード(葉)の演算。
  • f(a,b) -> c: 次数が2のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
  • merge(a,b) -> c: 字数が3以上のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列をまとめたリスト。
  • f_last(a,b): 根で行う最後の演算。
  • e: mergeの際の単位元
  • g(a,b) -> c: 累積値を使ってマージしていく時の二項演算

引数の設定例

典型003:各ノードi=0,...,N-1からの最大の距離max(dist(i, j))

atcoder.jp

def f0(a, b):
    return e

def f(a, b):
    return b+1

def merge(a, res):
    return max(res)+1

f_last = merge

e = 0

def g(a, b):
    return max(a, b)

ABC220F(他のノードの距離の総和sum(dist(i,j)))

atcoder.jp

def f0(a, b):
    return (1,0)

def f(a, b):
    sub0, val0 = b
    return (sub0+1,sub0+val0)

def merge(a, res):
    out = [1,0]
    for sub,val in res:
        out[0] += sub
        out[1] += sub + val
    return out

f_last = merge

e = (0,0)

def g(a, b):
    sub1,val1 = a
    sub2,val2 = b
    return sub1+sub2,val1+val2

ソースコード

# 全方位木DPの中で使うための関数
def bfs(root=0):
    N = len(G)
    visited = [0]*N
    visited[root] = 1
    q = [root]
    d = [0]*N
    prev = [-1]*N
    for v in q:
        for v2 in G[v]:
            if visited[v2]: continue
            visited[v2] = 1
            d[v2] = d[v] + 1
            q.append(v2)
            prev[v2] = v
    return d, prev

def tree_dp(root, f, f0, merge, f_last, e):
    N = len(G)
    d, prev = bfs(root)
    order, _ = zip(*sorted(enumerate(d),key=lambda x:x[1]))
    res = [e for _ in range(N)]
    for v in reversed(order[1:]):
        if len(G[v])==1:
            res[v] = f0(res[v], e)
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != prev[v]: break
            res[v] = f(res[v], res[v2])
        else:
            res_tmp = [res[v2] for v2 in G[v] if v2 != prev[v]]
            res[v] = merge(res[v], res_tmp)
    res_tmp = [res[v2] for v2 in G[root]]
    res[root] = f_last(res[root], res_tmp)
    return res, order, d, prev

# 全方位木DP本体
def rerooting(root, f0, f, merge, f_last, e, g):
    N = len(G)
    # 左右から累積するときにインデックスへの変換用
    G2i = [{} for i in range(N)]
    for i in range(N):
        G[i].sort()
        G2i[i] = {v:j for j,v in enumerate(G[i])}
    res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e)
    # 左右からの累積用の配列
    cumleft = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    cumright = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    res2 = [e]*N
    # 根の結果は同じ
    res2[root] = res[root]
    
    # 根の累積
    for v in G[root]:
        iv = G2i[root][v]
        cumleft[root][iv+1] = cumright[root][iv] = res[v]
    
    # 根の子ノードを先に計算
    if len(G[root])==1:
        v = G[root][0]
        iroot = G2i[v][root]
        cumleft[v][iroot+1] = cumright[v][iroot] = f0(_, _)
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    elif len(G[root])==2:
        v1, v2 = G[root]
        iroot = G2i[v1][root]
        cumleft[v1][iroot+1] = cumright[v1][iroot] = f(res2[root], res[v2])
        iroot = G2i[v2][root]
        cumleft[v2][iroot+1] = cumright[v2][iroot] = f(res2[root], res[v1])
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    else:
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
        for v in G[root]:
            iroot = G2i[v][root]
            iv = G2i[root][v]
            cumleft[v][iroot+1] = cumright[v][iroot] = merge(res2[v], [cumleft[root][iv], cumright[root][-iv-1]])
    
    # 根から近い順に更新
    for v in order[1:]:
        par = prev[v]
        ipar = G2i[v][par]
        iv = G2i[par][v]
        if len(G[par])==1:
            res_par = f0(_, _)
        else:
            res_par = merge(_, [cumleft[par][iv], cumright[par][iv+1]])
        cumleft[v][ipar+1] = cumright[v][ipar] = res_par
        if len(G[v])==1:
            res2[v] = f_last(res2[v], [res_par])
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != par: break
            res2[v] = f_last(res2[v], [res_par, res[v2]])
            iv2 = G2i[v][v2]
            cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])
        else:
            for v2 in G[v]:
                if v2 != par:
                    iv2 = G2i[v][v2]
                    cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            res2[v] = f_last(res2[v], [cumleft[v][i+1] for i in range(len(G[v]))])
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])

    return res2

バージョン2:左右累積バージョン(微修正版)

バージョン1とほぼ同じだが、一度木DPを終えて他のノードで再計算を行うとき、累積値をマージする関数merge2を定義可能にしたもの。EDPC-Vを解くために使用。

引数の設定例

EDPC-V

atcoder.jp

def f0(a, b):
    return e
def f(a, b):
    return b+1
def merge(a, res):
    out = 1
    for r in res:
        out *= r+1
        out %= M
    return out
  
def merge2(a,res):
  return res[0]*res[1]%M

f_last = merge

e = 1

def g(a, b):
    return (a+1)*b%M

ソースコード

def rerooting(root, f0, f, merge, f_last, e, g, merge2):
    N = len(G)
    # 左右から累積するときにインデックスへの変換用
    G2i = [{} for i in range(N)]
    for i in range(N):
        G[i].sort()
        G2i[i] = {v:j for j,v in enumerate(G[i])}
    res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e)
    # 左右からの累積用の配列
    cumleft = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    cumright = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    res2 = [e]*N
    # 根の結果は同じ
    res2[root] = res[root]
    
    # 根の累積
    for v in G[root]:
        iv = G2i[root][v]
        cumleft[root][iv+1] = cumright[root][iv] = res[v]
    
    # 根の子ノードを先に計算
    if len(G[root])==1:
        v = G[root][0]
        iroot = G2i[v][root]
        cumleft[v][iroot+1] = cumright[v][iroot] = f0(_, _)
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    elif len(G[root])==2:
        v1, v2 = G[root]
        iroot = G2i[v1][root]
        cumleft[v1][iroot+1] = cumright[v1][iroot] = f(res2[root], res[v2])
        iroot = G2i[v2][root]
        cumleft[v2][iroot+1] = cumright[v2][iroot] = f(res2[root], res[v1])
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    else:
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
        for v in G[root]:
            iroot = G2i[v][root]
            iv = G2i[root][v]
            cumleft[v][iroot+1] = cumright[v][iroot] = merge2(res2[v], [cumleft[root][iv], cumright[root][-iv-1]])
    
    # 根から近い順に更新
    for v in order[1:]:
        par = prev[v]
        ipar = G2i[v][par]
        iv = G2i[par][v]
        if len(G[par])==1:
            res_par = f0(_, _)
        else:
            res_par = merge2(_, [cumleft[par][iv], cumright[par][iv+1]])
        cumleft[v][ipar+1] = cumright[v][ipar] = res_par
        if len(G[v])==1:
            res2[v] = f_last(res2[v], [res_par])
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != par: break
            res2[v] = f_last(res2[v], [res_par, res[v2]])
            iv2 = G2i[v][v2]
            cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])
        else:
            for v2 in G[v]:
                if v2 != par:
                    iv2 = G2i[v][v2]
                    cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            res2[v] = f_last(res2[v], [cumleft[v][i+1] for i in range(len(G[v]))])
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])
    return res2

バージョン3:逆演算が定義できる場合に、全体の累積値から一部の累積値を計算するバージョン

引数に入れる関数の説明

  • f0(a, b): 次数が1のノード(葉)の演算。
  • f(a,b) -> c: 次数が2のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
  • merge(a,b) -> c: 字数が3以上のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
  • f_last(a,b): 根で行う最後の演算。
  • e: 単位元
  • prod: 累積をとるときの演算。二項演算が+ならsum, maxならmax, *ならprod
  • inv: 累積から引くときの演算。二項演算の逆演算。+なら-, *なら/もしくはmodinv
  • process: 累積から最後のノードを加えるときの演算。ほぼ同じか、少し変わるだけ。

mergeとf_last, mergeとprodは同じ場合が多いが一応分けている。

引数の設定例

ABC 160 F

atcoder.jp

# mod逆元を計算する関数
def modinv(x):
    return pow(x, mod-2, mod)

# 階乗の前計算
fac=[1 for i in range(2*N+10)]
for i in range(2,2*N+10):
    fac[i] = fac[i-1]*i % mod

# a, bはdp[v]を表す
# dp[v] = [t,val]: t: 部分木のサイズ、val: 求めたい値
def f0(a, b):
    return e[0]+1,e[1]

def f(a, b):
    return b[0]+1,b[1]

# resはdp配列のリスト[dp[v1], dp[v2], ...]
def merge(a, res):
    treesize = 0
    denom = 1
    val = 1
    for t,v in res:
        treesize += t
        denom *= fac[t]
        denom %= mod
        val *= v
        val %= mod
    return treesize + 1, fac[treesize]*modinv(denom)*val%mod

f_last = merge

e = (0,1)

def prod(res):
    treesize = 0
    denom = 1
    val = 1
    for t,v in res:
        treesize += t
        denom *= fac[t]
        denom %= mod
        val *= v
        val %= mod
    return treesize, fac[treesize]*modinv(denom)*val%mod
  
def inv(a,b):
  t0,v0 = a
  t1,v1 = b
  nom = fac[t1]*fac[t0-t1]%mod
  denom = fac[t0]*v1%mod
  out = t0-t1,v0*(nom*modinv(denom)%mod)%mod
  return out

def process(a):
  return a[0]+1,a[1]

ソースコード

def bfs(root=0):
    N = len(G)
    visited = [0]*N
    visited[root] = 1
    q = [root]
    d = [0]*N
    prev = [-1]*N
    for v in q:
        for v2 in G[v]:
            if visited[v2]: continue
            visited[v2] = 1
            d[v2] = d[v] + 1
            q.append(v2)
            prev[v2] = v
    return d, prev

def tree_dp(root, f, f0, merge, f_last, e):
    N = len(G)
    d, prev = bfs(root)
    order, _ = zip(*sorted(enumerate(d),key=lambda x:x[1]))
    res = [e for _ in range(N)]
    for v in reversed(order[1:]):
        if len(G[v])==1:
            res[v] = f0(res[v], e)
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != prev[v]: break
            res[v] = f(res[v], res[v2])
        else:
            res_tmp = [res[v2] for v2 in G[v] if v2 != prev[v]]
            res[v] = merge(res[v], res_tmp)
    res_tmp = [res[v2] for v2 in G[root]]
    res[root] = f_last(res[root], res_tmp)
    return res, order, d, prev

def rerooting2(root, f0, f, merge, f_last, e, prod, inv, process):
    N = len(G)
    res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e)
    # 左右からの累積用の配列
    SUM = [e for _ in range(N)]
    res2 = [e]*N
    # 根の結果は同じ
    res2[root] = res[root]
    
    # 根の累積
    SUM[root] = prod([res[v] for v in G[root]])
    
    # 根から近い順に更新
    for v in order[1:]:
        subtracted = inv(SUM[prev[v]], res[v])
        if len(G[prev[v]])==1:
            res_par = f0(_, _)
        else:
            res_par = process(subtracted)
        SUM[v] = prod([res[v2] for v2 in G[v] if v2 != prev[v]]+[res_par])
        res2[v] = process(SUM[v])

    return res2

【Python】バケット法(平方分割)

機能

セグ木と同様に、一点更新と区間取得ができます。ただし計算量はセグ木がの区間取得がO(logN)のところがバケット法だとO(\sqrt{N})かかり、遅くなります。基本的にセグ木と同じことしかできないにも関わらず計算量の点では性能は一段階劣るという感じなので出番はあまりないですが、セグ木に比べて構造が単純な分、問題に合わせて中身をいじったり、デバッグしたりはしやすいのではないかと考えています(今まで一度もそうした機会はないですが)。
また、演算の順番を崩さないので非可換モノイドにも対応可です。

使い方

  • インスタンスの作成:基本的にはセグ木と同じで、配列の初期値(あれば)と二項演算と単位元を引数として渡します。逆元がある演算の場合は逆元をinvとして指定することで少し計算が早くなります。引数sizeを使ってバケットサイズを指定することもできます(バケットサイズを指定しない場合、Nの平方根に近い値がデフォルトで割り当てられます)
    • bk = Bucket(N,data=A,f=lambda a,b:a+b,e=0)
    • (逆元がある場合)bk = Bucket(N,data=A,f=lambda a,b:a+b,e=0, inv=lambda a,b:a-b)
    • (バケットサイズを指定する場合)bk = Bucket(N,data=A,f=lambda a,b:a+b,e=0, inv=lambda a,b:a-b, size=100)
  • 一点更新(iは0-indexed)
    • bk.set_val(i, x)
  • 区間取得(0-indexedかつ半開区間。range(l,r)と同じ)
    • bk.prod(l,r)
  • 区間取得:始点から終点までの全区間の計算結果が欲しいときはbk.prod(0,N)とするよりこちらの方が速い
    • bk.prod_all()

使用例

-セグ木の1.5倍程度の時間でAC
- https://atcoder.jp/contests/practice2/submissions/22311944
- 逆元が定義されている場合は更新作業が一回の演算で済むので早くなる。
- (+の逆元は-) https://atcoder.jp/contests/practice2/submissions/22312223
- (xorの逆元もxor) https://atcoder.jp/contests/abc185/submissions/22312514
- 非可換モノイドにも対応(タコヤキオイシクナール)
- https://atcoder.jp/contests/arc008/submissions/22313073
- ライブラリチェッカーでは一部TLE
- https://judge.yosupo.jp/submission/61631

参考にした記事など

www.slideshare.net

最初に自力実装したデータ構造なので思い入れがあります。

class Bucket:
    def __init__(self, N, data=None, f=max, e=0, inv=None, size=None):
        self.N = N
        if not data: self.data = [e for _ in range(N)]
        else: self.data = data
        if not size: size = int(N**0.5)
        self.size, self.n = size,(N+size-1)//size
        self.f,  self.e, self.inv = f, e, inv
        self.middle = [e for _ in range(self.n)]
        if data:
            for i in range(self.n):
                for j in range(i*size, min((i+1)*size, N)): self.middle[i] = self.f(self.middle[i], self.data[j])
        self.all = e
        if data:
            for i in range(self.n): self.all = self.f(self.all, self.middle[i])
                
    def set_val(self, i, x):
        # 逆元が定義されている場合は一回の演算で済ませる。
        if self.inv:
            self.middle[i//self.size] = self.inv(self.middle[i//self.size], self.data[i])
            self.middle[i//self.size] = self.f(self.middle[i//self.size], x)
            self.all = self.inv(self.all, self.data[i])
            self.all = self.f(self.all, x)
            self.data[i] = x
            return
        self.data[i] = x
        ib = i//self.size
        self.middle[ib] = self.e
        for j in range(ib*self.size, min((ib+1)*self.size, N)): self.middle[ib] = self.f(self.middle[ib], self.data[j])
        self.all = self.e
        for i in range(self.n): self.all = self.f(self.all, self.middle[i])
            
    # [L, R)
    def prod(self, l, r):
        ibl, ibr = l//self.size, (r-1)//self.size
        out = self.e
        if ibl==ibr:
            for i in range(l, r): out = self.f(out, self.data[i])
            return out
        for i in range(l, (ibl+1)*self.size): out = self.f(out, self.data[i])
        for i in range(ibl+1, ibr): out = self.f(out, self.middle[i])
        for i in range(ibr*self.size, r): out = self.f(out, self.data[i])
        return out

    def prod_all(self): return self.all

【Python】非再帰抽象化 Segment Tree (通常バージョン、非可換対応バージョン)

バージョン1:可換モノイド用

機能

通常のセグメント木と同様、一要素ごとの更新、区間における演算結果を取得できる。抽象化してあり、二演算と単位元のペアを指定することで様々な可換モノイドに対応できる。

(二項演算と単位元の例)

二項演算 単位元
max 0 (負の数を含む場合は-INF)
min INF
+ 0
* 1
+ (mod) 0
* (mod) 1
xor 0
gcd 0
lcm 1

使い方

  • インスタンスの作成

    • (例)maxを計算するセグ木を作りたい場合:seg = SegTree(N=N, data=arr, f=max, e=0)
    • (例)* (mod) を計算するセグ木を作りたい場合:seg = SegTree(N=N, data=arr, f=lambda a,b: a * b % mod, e=1)
    • (例) xorを計算するセグ木を作りたい場合:seg = SegTree(N=N, data=arr, f=lambda a,b: a ^ b, e=0)
  • インデックスi (0-indexed)の要素の値をxに変更

    • seg.set_val(i, x)
  • 区間[l, r)(0-indexedかつ半開区間。つまりpythonのrange(l, r)と同じ)の計算結果を出力。

    • seg.prod(l, r)

使用例

atcoder.jp

参考にした記事など

tsutaj.hatenablog.com

class SegTree:
    def __init__(self, N, data=None, f=max, e=0):
        self.N = N
        if not data:
            data = [e for _ in range(N)]
        self.data = [e for _ in range(N-1)]+data
        self.f = f
        self.e = e
        for i in range(2*N-3, 0, -2):
            self.data[i>>1] = self.f(self.data[i],self.data[i+1])

    def set_val(self, i, x):
        i += self.N-1
        self.data[i] = x
        while i:
            if i%2==0: i-= 1
            self.data[i>>1] = self.f(self.data[i],self.data[i+1])
            i  >>= 1
            
    # [L, R)
    def prod(self, l, r):
        if r-l==1: return self.data[l+self.N-1]
        l += self.N-1
        r+=self.N-2
        out = self.e
        while r > l:
            if l%2==0:
                out = self.f(out, self.data[l])
                l += 1
            if r%2:
                out = self.f(out, self.data[r])
                r -= 1
            if r-l==1:
                out = self.f(out, self.data[l>>1])
                break
            elif r==l:
                out = self.f(out, self.data[l])
                break
            l -= 1
            l >>= 1
            r -= 1
            r >>= 1
        return out

バージョン2: 非可換モノイド対応版

機能

可換モノイドに加えて、非可換モノイドにも対応(つまり、配列の前後の順番を変えずに計算することができる)したバージョン。順番を前後させると計算結果が変わってしまう二項演算の例として、-, / (割り算), ax+b (アフィン変換), 行列の積などがある。

使い方

バージョン1と同じ。

使用例

class SegTree:
    def __init__(self, N, data=None, f=max, e=0):
        self.N = N
        if not data:
            data = [e for _ in range(N)]
        self.data = [e for _ in range(N-1)]+data
        self.f = f
        self.e = e
        for i in range(2*N-3, 0, -2):
            self.data[i>>1] = self.f(self.data[i],self.data[i+1])

    def set_val(self, i, x):
        i += self.N-1
        self.data[i] = x
        while i:
            if i%2==0: i-= 1
            self.data[i>>1] = self.f(self.data[i],self.data[i+1])
            i  >>= 1
            
    # [L, R)
    # 下から上げていって、要求区間が対象区間を完全に含んでいれば区間のインデックスをスタックに突っ込む。
    # 上に上がるにつれて、徐々に範囲を狭めていく
    def prod(self, l, r):
        if r-l==1: return self.data[l+self.N-1]
        l += self.N-1
        r+=self.N-2
        left,right=[],[]
        while r > l:
            if l%2==0:
                left.append(l)
                l += 1
            if r%2:
                right.append(r)
                r -= 1
            if r-l==1:
                left.append(l>>1)
                break
            elif r==l:
                left.append(l)
                break
            l -= 1
            l >>= 1
            r -= 1
            r >>= 1
        res = left+right[::-1]
        out = self.data[res[0]]
        for i in range(1,len(res)):
            out = self.f(out, self.data[res[i]])
        return out

【Python】削除可能優先度付きキュー

機能

通常の優先度付きキューの機能
- 要素の追加:O(logN)
- 最小値(最大値)の出力または削除:O(logN)
に加えて、
- 特定の要素の削除:O(logN)
をすることができます。

使い方

  • キューのインスタンス作成
    • 最小値の取得をするキュー:q = DeletablePQ()(入れたいデータがある場合はq=DeletablePQ(data=arr))
    • 最大値の取得をするキュー:q = DeletablePQ(greater=True)(入れたいデータがある場合はq=DeletablePQ(data=arr,greater=True))
  • 要素xの追加
    • q.push(x)
  • 要素xの削除
    • q.remove(x)
  • 要素xをn個削除
    • q.nremove(x, num=n)
  • 最小値(最大値)を削除返り値あり)
    • x = q.pop()
  • xの個数を出力
    • q.count(x)
  • 最小値(最大値)を出力(削除はせず、参照するだけ
    • x = q.get_val()
  • 要素の総数を出力
    • len(q)
  • 他のDeletablePQのインスタンスとマージ(要素数の多いインスタンスをベースに行うとより速いです)
    • q.merge(other_q)

使用例

from heapq import heappush,heappop
 
class DeletablePQ:
  def __init__(self,data=None,greater=False):
    self.greater=greater
    self.count= {}
    self.que=[]
    if data:
        for x in data:
            heappush(self.que,x*(-1 if greater else 1))
            if self.count.get(x): self.count[x] += 1
            else: self.count[x] = 1
    self.len = len(self.que)
 
  def push(self,x):
    heappush(self.que,x*(-1 if self.greater else 1))
    if self.count.get(x): self.count[x] += 1
    else: self.count[x] = 1
    self.len += 1
    
  def remove(self,x):
    self.count[x] -= 1
    self.len -= 1
  
  def nremove(self,x,num=1):
    self.count[x] -= num
    self.len -= num
    if self.count[x] < 0:
      self.len -= self.count[x]
      self.count[x] = 0
    
  def pop(self):
    x,self.count[x]=-1,0
    while self.count[x]<1: x=heappop(self.que)*(-1 if self.greater else 1)
    self.len -= 1
    self.count[x] -= 1
    return x
  
  def count(self,x):
    return self.count[x] if self.count[x]>0 else 0
  
  def get_val(self):
    while self.count.get(self.que[0]*(-1 if self.greater else 1)) != None and self.count[self.que[0]*(-1 if self.greater else 1)]<1:
        heappop(self.que)
    return self.que[0]*(-1 if self.greater else 1)
 
  def merge(self, other):
        while len(other) > 0: self.push(other.pop())
  
  def __len__(self):
    return self.len