【Python】全方位木DP (書きかけ)

【注意】この記事は書きかけです。

抽象化全方位木DPのライブラリです。全3バージョンあります。多数の引数が入り乱れている実装で改善の余地がありますが、問題をACできる程度には動きます。

実装にあたっては、すぬけさんによるABCの解説放送(https://www.youtube.com/watch?v=zG1L4vYuGrg)のほか、下記の記事が大変参考になりました。
- https://qiita.com/Kiri8128/items/a011c90d25911bdb3ed3
- https://ei1333.hateblo.jp/entry/2017/04/10/224413
- https://algo-logic.info/tree-dp/
- https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e

バージョン1:左右累積バージョン

引数に入れる関数の説明

  • f0(a, b): 次数が1のノード(葉)の演算。
  • f(a,b) -> c: 次数が2のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
  • merge(a,b) -> c: 字数が3以上のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列をまとめたリスト。
  • f_last(a,b): 根で行う最後の演算。
  • e: mergeの際の単位元
  • g(a,b) -> c: 累積値を使ってマージしていく時の二項演算

引数の設定例

典型003:各ノードi=0,...,N-1からの最大の距離max(dist(i, j))

atcoder.jp

def f0(a, b):
    return e

def f(a, b):
    return b+1

def merge(a, res):
    return max(res)+1

f_last = merge

e = 0

def g(a, b):
    return max(a, b)

ABC220F(他のノードの距離の総和sum(dist(i,j)))

atcoder.jp

def f0(a, b):
    return (1,0)

def f(a, b):
    sub0, val0 = b
    return (sub0+1,sub0+val0)

def merge(a, res):
    out = [1,0]
    for sub,val in res:
        out[0] += sub
        out[1] += sub + val
    return out

f_last = merge

e = (0,0)

def g(a, b):
    sub1,val1 = a
    sub2,val2 = b
    return sub1+sub2,val1+val2

ソースコード

# 全方位木DPの中で使うための関数
def bfs(root=0):
    N = len(G)
    visited = [0]*N
    visited[root] = 1
    q = [root]
    d = [0]*N
    prev = [-1]*N
    for v in q:
        for v2 in G[v]:
            if visited[v2]: continue
            visited[v2] = 1
            d[v2] = d[v] + 1
            q.append(v2)
            prev[v2] = v
    return d, prev

def tree_dp(root, f, f0, merge, f_last, e):
    N = len(G)
    d, prev = bfs(root)
    order, _ = zip(*sorted(enumerate(d),key=lambda x:x[1]))
    res = [e for _ in range(N)]
    for v in reversed(order[1:]):
        if len(G[v])==1:
            res[v] = f0(res[v], e)
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != prev[v]: break
            res[v] = f(res[v], res[v2])
        else:
            res_tmp = [res[v2] for v2 in G[v] if v2 != prev[v]]
            res[v] = merge(res[v], res_tmp)
    res_tmp = [res[v2] for v2 in G[root]]
    res[root] = f_last(res[root], res_tmp)
    return res, order, d, prev

# 全方位木DP本体
def rerooting(root, f0, f, merge, f_last, e, g):
    N = len(G)
    # 左右から累積するときにインデックスへの変換用
    G2i = [{} for i in range(N)]
    for i in range(N):
        G[i].sort()
        G2i[i] = {v:j for j,v in enumerate(G[i])}
    res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e)
    # 左右からの累積用の配列
    cumleft = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    cumright = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    res2 = [e]*N
    # 根の結果は同じ
    res2[root] = res[root]
    
    # 根の累積
    for v in G[root]:
        iv = G2i[root][v]
        cumleft[root][iv+1] = cumright[root][iv] = res[v]
    
    # 根の子ノードを先に計算
    if len(G[root])==1:
        v = G[root][0]
        iroot = G2i[v][root]
        cumleft[v][iroot+1] = cumright[v][iroot] = f0(_, _)
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    elif len(G[root])==2:
        v1, v2 = G[root]
        iroot = G2i[v1][root]
        cumleft[v1][iroot+1] = cumright[v1][iroot] = f(res2[root], res[v2])
        iroot = G2i[v2][root]
        cumleft[v2][iroot+1] = cumright[v2][iroot] = f(res2[root], res[v1])
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    else:
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
        for v in G[root]:
            iroot = G2i[v][root]
            iv = G2i[root][v]
            cumleft[v][iroot+1] = cumright[v][iroot] = merge(res2[v], [cumleft[root][iv], cumright[root][-iv-1]])
    
    # 根から近い順に更新
    for v in order[1:]:
        par = prev[v]
        ipar = G2i[v][par]
        iv = G2i[par][v]
        if len(G[par])==1:
            res_par = f0(_, _)
        else:
            res_par = merge(_, [cumleft[par][iv], cumright[par][iv+1]])
        cumleft[v][ipar+1] = cumright[v][ipar] = res_par
        if len(G[v])==1:
            res2[v] = f_last(res2[v], [res_par])
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != par: break
            res2[v] = f_last(res2[v], [res_par, res[v2]])
            iv2 = G2i[v][v2]
            cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])
        else:
            for v2 in G[v]:
                if v2 != par:
                    iv2 = G2i[v][v2]
                    cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            res2[v] = f_last(res2[v], [cumleft[v][i+1] for i in range(len(G[v]))])
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])

    return res2

バージョン2:左右累積バージョン(微修正版)

バージョン1とほぼ同じだが、一度木DPを終えて他のノードで再計算を行うとき、累積値をマージする関数merge2を定義可能にしたもの。EDPC-Vを解くために使用。

引数の設定例

EDPC-V

atcoder.jp

def f0(a, b):
    return e
def f(a, b):
    return b+1
def merge(a, res):
    out = 1
    for r in res:
        out *= r+1
        out %= M
    return out
  
def merge2(a,res):
  return res[0]*res[1]%M

f_last = merge

e = 1

def g(a, b):
    return (a+1)*b%M

ソースコード

def rerooting(root, f0, f, merge, f_last, e, g, merge2):
    N = len(G)
    # 左右から累積するときにインデックスへの変換用
    G2i = [{} for i in range(N)]
    for i in range(N):
        G[i].sort()
        G2i[i] = {v:j for j,v in enumerate(G[i])}
    res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e)
    # 左右からの累積用の配列
    cumleft = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    cumright = [[e for _ in range(len(G[i])+1)] for i in range(N)]
    res2 = [e]*N
    # 根の結果は同じ
    res2[root] = res[root]
    
    # 根の累積
    for v in G[root]:
        iv = G2i[root][v]
        cumleft[root][iv+1] = cumright[root][iv] = res[v]
    
    # 根の子ノードを先に計算
    if len(G[root])==1:
        v = G[root][0]
        iroot = G2i[v][root]
        cumleft[v][iroot+1] = cumright[v][iroot] = f0(_, _)
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    elif len(G[root])==2:
        v1, v2 = G[root]
        iroot = G2i[v1][root]
        cumleft[v1][iroot+1] = cumright[v1][iroot] = f(res2[root], res[v2])
        iroot = G2i[v2][root]
        cumleft[v2][iroot+1] = cumright[v2][iroot] = f(res2[root], res[v1])
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
    else:
        for i in range(len(G[root])):
            cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i])
            cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1])
        for v in G[root]:
            iroot = G2i[v][root]
            iv = G2i[root][v]
            cumleft[v][iroot+1] = cumright[v][iroot] = merge2(res2[v], [cumleft[root][iv], cumright[root][-iv-1]])
    
    # 根から近い順に更新
    for v in order[1:]:
        par = prev[v]
        ipar = G2i[v][par]
        iv = G2i[par][v]
        if len(G[par])==1:
            res_par = f0(_, _)
        else:
            res_par = merge2(_, [cumleft[par][iv], cumright[par][iv+1]])
        cumleft[v][ipar+1] = cumright[v][ipar] = res_par
        if len(G[v])==1:
            res2[v] = f_last(res2[v], [res_par])
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != par: break
            res2[v] = f_last(res2[v], [res_par, res[v2]])
            iv2 = G2i[v][v2]
            cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])
        else:
            for v2 in G[v]:
                if v2 != par:
                    iv2 = G2i[v][v2]
                    cumleft[v][iv2+1] = cumright[v][iv2] = res[v2]
            res2[v] = f_last(res2[v], [cumleft[v][i+1] for i in range(len(G[v]))])
            for i in range(len(G[v])):
                cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i])
                cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1])
    return res2

バージョン3:逆演算が定義できる場合に、全体の累積値から一部の累積値を計算するバージョン

引数に入れる関数の説明

  • f0(a, b): 次数が1のノード(葉)の演算。
  • f(a,b) -> c: 次数が2のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
  • merge(a,b) -> c: 字数が3以上のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
  • f_last(a,b): 根で行う最後の演算。
  • e: 単位元
  • prod: 累積をとるときの演算。二項演算が+ならsum, maxならmax, *ならprod
  • inv: 累積から引くときの演算。二項演算の逆演算。+なら-, *なら/もしくはmodinv
  • process: 累積から最後のノードを加えるときの演算。ほぼ同じか、少し変わるだけ。

mergeとf_last, mergeとprodは同じ場合が多いが一応分けている。

引数の設定例

ABC 160 F

atcoder.jp

# mod逆元を計算する関数
def modinv(x):
    return pow(x, mod-2, mod)

# 階乗の前計算
fac=[1 for i in range(2*N+10)]
for i in range(2,2*N+10):
    fac[i] = fac[i-1]*i % mod

# a, bはdp[v]を表す
# dp[v] = [t,val]: t: 部分木のサイズ、val: 求めたい値
def f0(a, b):
    return e[0]+1,e[1]

def f(a, b):
    return b[0]+1,b[1]

# resはdp配列のリスト[dp[v1], dp[v2], ...]
def merge(a, res):
    treesize = 0
    denom = 1
    val = 1
    for t,v in res:
        treesize += t
        denom *= fac[t]
        denom %= mod
        val *= v
        val %= mod
    return treesize + 1, fac[treesize]*modinv(denom)*val%mod

f_last = merge

e = (0,1)

def prod(res):
    treesize = 0
    denom = 1
    val = 1
    for t,v in res:
        treesize += t
        denom *= fac[t]
        denom %= mod
        val *= v
        val %= mod
    return treesize, fac[treesize]*modinv(denom)*val%mod
  
def inv(a,b):
  t0,v0 = a
  t1,v1 = b
  nom = fac[t1]*fac[t0-t1]%mod
  denom = fac[t0]*v1%mod
  out = t0-t1,v0*(nom*modinv(denom)%mod)%mod
  return out

def process(a):
  return a[0]+1,a[1]

ソースコード

def bfs(root=0):
    N = len(G)
    visited = [0]*N
    visited[root] = 1
    q = [root]
    d = [0]*N
    prev = [-1]*N
    for v in q:
        for v2 in G[v]:
            if visited[v2]: continue
            visited[v2] = 1
            d[v2] = d[v] + 1
            q.append(v2)
            prev[v2] = v
    return d, prev

def tree_dp(root, f, f0, merge, f_last, e):
    N = len(G)
    d, prev = bfs(root)
    order, _ = zip(*sorted(enumerate(d),key=lambda x:x[1]))
    res = [e for _ in range(N)]
    for v in reversed(order[1:]):
        if len(G[v])==1:
            res[v] = f0(res[v], e)
        elif len(G[v])==2:
            for v2 in G[v]:
                if v2 != prev[v]: break
            res[v] = f(res[v], res[v2])
        else:
            res_tmp = [res[v2] for v2 in G[v] if v2 != prev[v]]
            res[v] = merge(res[v], res_tmp)
    res_tmp = [res[v2] for v2 in G[root]]
    res[root] = f_last(res[root], res_tmp)
    return res, order, d, prev

def rerooting2(root, f0, f, merge, f_last, e, prod, inv, process):
    N = len(G)
    res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e)
    # 左右からの累積用の配列
    SUM = [e for _ in range(N)]
    res2 = [e]*N
    # 根の結果は同じ
    res2[root] = res[root]
    
    # 根の累積
    SUM[root] = prod([res[v] for v in G[root]])
    
    # 根から近い順に更新
    for v in order[1:]:
        subtracted = inv(SUM[prev[v]], res[v])
        if len(G[prev[v]])==1:
            res_par = f0(_, _)
        else:
            res_par = process(subtracted)
        SUM[v] = prod([res[v2] for v2 in G[v] if v2 != prev[v]]+[res_par])
        res2[v] = process(SUM[v])

    return res2