【Python】全方位木DP (書きかけ)
【注意】この記事は書きかけです。
抽象化全方位木DPのライブラリです。全3バージョンあります。多数の引数が入り乱れている実装で改善の余地がありますが、問題をACできる程度には動きます。
実装にあたっては、すぬけさんによるABCの解説放送(https://www.youtube.com/watch?v=zG1L4vYuGrg)のほか、下記の記事が大変参考になりました。
- https://qiita.com/Kiri8128/items/a011c90d25911bdb3ed3
- https://ei1333.hateblo.jp/entry/2017/04/10/224413
- https://algo-logic.info/tree-dp/
- https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e
バージョン1:左右累積バージョン
引数に入れる関数の説明
- f0(a, b): 次数が1のノード(葉)の演算。
- f(a,b) -> c: 次数が2のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
- merge(a,b) -> c: 字数が3以上のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列をまとめたリスト。
- f_last(a,b): 根で行う最後の演算。
- e: mergeの際の単位元
- g(a,b) -> c: 累積値を使ってマージしていく時の二項演算
引数の設定例
典型003:各ノードi=0,...,N-1からの最大の距離max(dist(i, j))
def f0(a, b): return e def f(a, b): return b+1 def merge(a, res): return max(res)+1 f_last = merge e = 0 def g(a, b): return max(a, b)
ABC220F(他のノードの距離の総和sum(dist(i,j)))
def f0(a, b): return (1,0) def f(a, b): sub0, val0 = b return (sub0+1,sub0+val0) def merge(a, res): out = [1,0] for sub,val in res: out[0] += sub out[1] += sub + val return out f_last = merge e = (0,0) def g(a, b): sub1,val1 = a sub2,val2 = b return sub1+sub2,val1+val2
ソースコード
# 全方位木DPの中で使うための関数 def bfs(root=0): N = len(G) visited = [0]*N visited[root] = 1 q = [root] d = [0]*N prev = [-1]*N for v in q: for v2 in G[v]: if visited[v2]: continue visited[v2] = 1 d[v2] = d[v] + 1 q.append(v2) prev[v2] = v return d, prev def tree_dp(root, f, f0, merge, f_last, e): N = len(G) d, prev = bfs(root) order, _ = zip(*sorted(enumerate(d),key=lambda x:x[1])) res = [e for _ in range(N)] for v in reversed(order[1:]): if len(G[v])==1: res[v] = f0(res[v], e) elif len(G[v])==2: for v2 in G[v]: if v2 != prev[v]: break res[v] = f(res[v], res[v2]) else: res_tmp = [res[v2] for v2 in G[v] if v2 != prev[v]] res[v] = merge(res[v], res_tmp) res_tmp = [res[v2] for v2 in G[root]] res[root] = f_last(res[root], res_tmp) return res, order, d, prev # 全方位木DP本体 def rerooting(root, f0, f, merge, f_last, e, g): N = len(G) # 左右から累積するときにインデックスへの変換用 G2i = [{} for i in range(N)] for i in range(N): G[i].sort() G2i[i] = {v:j for j,v in enumerate(G[i])} res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e) # 左右からの累積用の配列 cumleft = [[e for _ in range(len(G[i])+1)] for i in range(N)] cumright = [[e for _ in range(len(G[i])+1)] for i in range(N)] res2 = [e]*N # 根の結果は同じ res2[root] = res[root] # 根の累積 for v in G[root]: iv = G2i[root][v] cumleft[root][iv+1] = cumright[root][iv] = res[v] # 根の子ノードを先に計算 if len(G[root])==1: v = G[root][0] iroot = G2i[v][root] cumleft[v][iroot+1] = cumright[v][iroot] = f0(_, _) for i in range(len(G[root])): cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i]) cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1]) elif len(G[root])==2: v1, v2 = G[root] iroot = G2i[v1][root] cumleft[v1][iroot+1] = cumright[v1][iroot] = f(res2[root], res[v2]) iroot = G2i[v2][root] cumleft[v2][iroot+1] = cumright[v2][iroot] = f(res2[root], res[v1]) for i in range(len(G[root])): cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i]) cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1]) else: for i in range(len(G[root])): cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i]) cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1]) for v in G[root]: iroot = G2i[v][root] iv = G2i[root][v] cumleft[v][iroot+1] = cumright[v][iroot] = merge(res2[v], [cumleft[root][iv], cumright[root][-iv-1]]) # 根から近い順に更新 for v in order[1:]: par = prev[v] ipar = G2i[v][par] iv = G2i[par][v] if len(G[par])==1: res_par = f0(_, _) else: res_par = merge(_, [cumleft[par][iv], cumright[par][iv+1]]) cumleft[v][ipar+1] = cumright[v][ipar] = res_par if len(G[v])==1: res2[v] = f_last(res2[v], [res_par]) elif len(G[v])==2: for v2 in G[v]: if v2 != par: break res2[v] = f_last(res2[v], [res_par, res[v2]]) iv2 = G2i[v][v2] cumleft[v][iv2+1] = cumright[v][iv2] = res[v2] for i in range(len(G[v])): cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i]) cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1]) else: for v2 in G[v]: if v2 != par: iv2 = G2i[v][v2] cumleft[v][iv2+1] = cumright[v][iv2] = res[v2] res2[v] = f_last(res2[v], [cumleft[v][i+1] for i in range(len(G[v]))]) for i in range(len(G[v])): cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i]) cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1]) return res2
バージョン2:左右累積バージョン(微修正版)
バージョン1とほぼ同じだが、一度木DPを終えて他のノードで再計算を行うとき、累積値をマージする関数merge2を定義可能にしたもの。EDPC-Vを解くために使用。
引数の設定例
EDPC-V
def f0(a, b): return e def f(a, b): return b+1 def merge(a, res): out = 1 for r in res: out *= r+1 out %= M return out def merge2(a,res): return res[0]*res[1]%M f_last = merge e = 1 def g(a, b): return (a+1)*b%M
ソースコード
def rerooting(root, f0, f, merge, f_last, e, g, merge2): N = len(G) # 左右から累積するときにインデックスへの変換用 G2i = [{} for i in range(N)] for i in range(N): G[i].sort() G2i[i] = {v:j for j,v in enumerate(G[i])} res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e) # 左右からの累積用の配列 cumleft = [[e for _ in range(len(G[i])+1)] for i in range(N)] cumright = [[e for _ in range(len(G[i])+1)] for i in range(N)] res2 = [e]*N # 根の結果は同じ res2[root] = res[root] # 根の累積 for v in G[root]: iv = G2i[root][v] cumleft[root][iv+1] = cumright[root][iv] = res[v] # 根の子ノードを先に計算 if len(G[root])==1: v = G[root][0] iroot = G2i[v][root] cumleft[v][iroot+1] = cumright[v][iroot] = f0(_, _) for i in range(len(G[root])): cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i]) cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1]) elif len(G[root])==2: v1, v2 = G[root] iroot = G2i[v1][root] cumleft[v1][iroot+1] = cumright[v1][iroot] = f(res2[root], res[v2]) iroot = G2i[v2][root] cumleft[v2][iroot+1] = cumright[v2][iroot] = f(res2[root], res[v1]) for i in range(len(G[root])): cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i]) cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1]) else: for i in range(len(G[root])): cumleft[root][i+1] = g(cumleft[root][i+1], cumleft[root][i]) cumright[root][-i-2] = g(cumright[root][-i-2], cumright[root][-i-1]) for v in G[root]: iroot = G2i[v][root] iv = G2i[root][v] cumleft[v][iroot+1] = cumright[v][iroot] = merge2(res2[v], [cumleft[root][iv], cumright[root][-iv-1]]) # 根から近い順に更新 for v in order[1:]: par = prev[v] ipar = G2i[v][par] iv = G2i[par][v] if len(G[par])==1: res_par = f0(_, _) else: res_par = merge2(_, [cumleft[par][iv], cumright[par][iv+1]]) cumleft[v][ipar+1] = cumright[v][ipar] = res_par if len(G[v])==1: res2[v] = f_last(res2[v], [res_par]) elif len(G[v])==2: for v2 in G[v]: if v2 != par: break res2[v] = f_last(res2[v], [res_par, res[v2]]) iv2 = G2i[v][v2] cumleft[v][iv2+1] = cumright[v][iv2] = res[v2] for i in range(len(G[v])): cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i]) cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1]) else: for v2 in G[v]: if v2 != par: iv2 = G2i[v][v2] cumleft[v][iv2+1] = cumright[v][iv2] = res[v2] res2[v] = f_last(res2[v], [cumleft[v][i+1] for i in range(len(G[v]))]) for i in range(len(G[v])): cumleft[v][i+1] = g(cumleft[v][i+1], cumleft[v][i]) cumright[v][-i-2] = g(cumright[v][-i-2], cumright[v][-i-1]) return res2
バージョン3:逆演算が定義できる場合に、全体の累積値から一部の累積値を計算するバージョン
引数に入れる関数の説明
- f0(a, b): 次数が1のノード(葉)の演算。
- f(a,b) -> c: 次数が2のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
- merge(a,b) -> c: 字数が3以上のノードiの演算。aはノードi自身のdp配列、bは子ノードのdp配列。
- f_last(a,b): 根で行う最後の演算。
- e: 単位元
- prod: 累積をとるときの演算。二項演算が+ならsum, maxならmax, *ならprod
- inv: 累積から引くときの演算。二項演算の逆演算。+なら-, *なら/もしくはmodinv
- process: 累積から最後のノードを加えるときの演算。ほぼ同じか、少し変わるだけ。
mergeとf_last, mergeとprodは同じ場合が多いが一応分けている。
引数の設定例
ABC 160 F
# mod逆元を計算する関数 def modinv(x): return pow(x, mod-2, mod) # 階乗の前計算 fac=[1 for i in range(2*N+10)] for i in range(2,2*N+10): fac[i] = fac[i-1]*i % mod # a, bはdp[v]を表す # dp[v] = [t,val]: t: 部分木のサイズ、val: 求めたい値 def f0(a, b): return e[0]+1,e[1] def f(a, b): return b[0]+1,b[1] # resはdp配列のリスト[dp[v1], dp[v2], ...] def merge(a, res): treesize = 0 denom = 1 val = 1 for t,v in res: treesize += t denom *= fac[t] denom %= mod val *= v val %= mod return treesize + 1, fac[treesize]*modinv(denom)*val%mod f_last = merge e = (0,1) def prod(res): treesize = 0 denom = 1 val = 1 for t,v in res: treesize += t denom *= fac[t] denom %= mod val *= v val %= mod return treesize, fac[treesize]*modinv(denom)*val%mod def inv(a,b): t0,v0 = a t1,v1 = b nom = fac[t1]*fac[t0-t1]%mod denom = fac[t0]*v1%mod out = t0-t1,v0*(nom*modinv(denom)%mod)%mod return out def process(a): return a[0]+1,a[1]
ソースコード
def bfs(root=0): N = len(G) visited = [0]*N visited[root] = 1 q = [root] d = [0]*N prev = [-1]*N for v in q: for v2 in G[v]: if visited[v2]: continue visited[v2] = 1 d[v2] = d[v] + 1 q.append(v2) prev[v2] = v return d, prev def tree_dp(root, f, f0, merge, f_last, e): N = len(G) d, prev = bfs(root) order, _ = zip(*sorted(enumerate(d),key=lambda x:x[1])) res = [e for _ in range(N)] for v in reversed(order[1:]): if len(G[v])==1: res[v] = f0(res[v], e) elif len(G[v])==2: for v2 in G[v]: if v2 != prev[v]: break res[v] = f(res[v], res[v2]) else: res_tmp = [res[v2] for v2 in G[v] if v2 != prev[v]] res[v] = merge(res[v], res_tmp) res_tmp = [res[v2] for v2 in G[root]] res[root] = f_last(res[root], res_tmp) return res, order, d, prev def rerooting2(root, f0, f, merge, f_last, e, prod, inv, process): N = len(G) res, order, d, prev = tree_dp(root, f, f0, merge, f_last, e) # 左右からの累積用の配列 SUM = [e for _ in range(N)] res2 = [e]*N # 根の結果は同じ res2[root] = res[root] # 根の累積 SUM[root] = prod([res[v] for v in G[root]]) # 根から近い順に更新 for v in order[1:]: subtracted = inv(SUM[prev[v]], res[v]) if len(G[prev[v]])==1: res_par = f0(_, _) else: res_par = process(subtracted) SUM[v] = prod([res[v2] for v2 in G[v] if v2 != prev[v]]+[res_par]) res2[v] = process(SUM[v]) return res2