【Python】非再帰抽象化 Segment Tree (通常バージョン、非可換対応バージョン)
バージョン1:可換モノイド用
機能
通常のセグメント木と同様、一要素ごとの更新、区間における演算結果を取得できる。抽象化してあり、二演算と単位元のペアを指定することで様々な可換モノイドに対応できる。
(二項演算と単位元の例)
二項演算 | 単位元 |
---|---|
max | 0 (負の数を含む場合は-INF) |
min | INF |
+ | 0 |
* | 1 |
+ (mod) | 0 |
* (mod) | 1 |
xor | 0 |
gcd | 0 |
lcm | 1 |
使い方
インスタンスの作成
- (例)maxを計算するセグ木を作りたい場合:seg = SegTree(N=N, data=arr, f=max, e=0)
- (例)* (mod) を計算するセグ木を作りたい場合:seg = SegTree(N=N, data=arr, f=lambda a,b: a * b % mod, e=1)
- (例) xorを計算するセグ木を作りたい場合:seg = SegTree(N=N, data=arr, f=lambda a,b: a ^ b, e=0)
インデックスi (0-indexed)の要素の値をxに変更
- seg.set_val(i, x)
区間[l, r)(0-indexedかつ半開区間。つまりpythonのrange(l, r)と同じ)の計算結果を出力。
- seg.prod(l, r)
使用例
参考にした記事など
class SegTree: def __init__(self, N, data=None, f=max, e=0): self.N = N if not data: data = [e for _ in range(N)] self.data = [e for _ in range(N-1)]+data self.f = f self.e = e for i in range(2*N-3, 0, -2): self.data[i>>1] = self.f(self.data[i],self.data[i+1]) def set_val(self, i, x): i += self.N-1 self.data[i] = x while i: if i%2==0: i-= 1 self.data[i>>1] = self.f(self.data[i],self.data[i+1]) i >>= 1 # [L, R) def prod(self, l, r): if r-l==1: return self.data[l+self.N-1] l += self.N-1 r+=self.N-2 out = self.e while r > l: if l%2==0: out = self.f(out, self.data[l]) l += 1 if r%2: out = self.f(out, self.data[r]) r -= 1 if r-l==1: out = self.f(out, self.data[l>>1]) break elif r==l: out = self.f(out, self.data[l]) break l -= 1 l >>= 1 r -= 1 r >>= 1 return out
バージョン2: 非可換モノイド対応版
機能
可換モノイドに加えて、非可換モノイドにも対応(つまり、配列の前後の順番を変えずに計算することができる)したバージョン。順番を前後させると計算結果が変わってしまう二項演算の例として、-, / (割り算), ax+b (アフィン変換), 行列の積などがある。
使い方
バージョン1と同じ。
使用例
- アフィン変換(タコヤキオイシクナール) atcoder.jp また、バージョン1と同じく通常の可換モノイドにも使うことができる。
- (+) https://atcoder.jp/contests/practice2/submissions/22268209
- (xor) https://atcoder.jp/contests/abc185/submissions/22312435
class SegTree: def __init__(self, N, data=None, f=max, e=0): self.N = N if not data: data = [e for _ in range(N)] self.data = [e for _ in range(N-1)]+data self.f = f self.e = e for i in range(2*N-3, 0, -2): self.data[i>>1] = self.f(self.data[i],self.data[i+1]) def set_val(self, i, x): i += self.N-1 self.data[i] = x while i: if i%2==0: i-= 1 self.data[i>>1] = self.f(self.data[i],self.data[i+1]) i >>= 1 # [L, R) # 下から上げていって、要求区間が対象区間を完全に含んでいれば区間のインデックスをスタックに突っ込む。 # 上に上がるにつれて、徐々に範囲を狭めていく def prod(self, l, r): if r-l==1: return self.data[l+self.N-1] l += self.N-1 r+=self.N-2 left,right=[],[] while r > l: if l%2==0: left.append(l) l += 1 if r%2: right.append(r) r -= 1 if r-l==1: left.append(l>>1) break elif r==l: left.append(l) break l -= 1 l >>= 1 r -= 1 r >>= 1 res = left+right[::-1] out = self.data[res[0]] for i in range(1,len(res)): out = self.f(out, self.data[res[i]]) return out