ウェーブレット行列
任意区間のK番目の要素を高速に取得するデータ構造
https://ei1333.github.io/library/structure/wavelet/wavelet-matrix.cpp.html
前計算 クエリ
- WaveletMatrix(v): 各要素の高さ v を初期値として構築する.
- access(k): k 番目の要素を返す.
- rank(x, r): 区間 [0,r) に含まれる x の個数を返す.
- kth_smallest(l, r, k): 区間 [l,r) に含まれる要素のうち k 番目(0-indexed) に小さいものを返す.
- kth_largest(l, r, k): 区間 [l,r)に含まれる要素のうち k 番目 (0-indexed) に大きいものを返す.
- range_freq(l, r, lower, upper): 区間 [l,r) に含まれる要素のうち [lower,upper) である要素数を返す.
- prev_value(l, r, upper): 区間 [l,r) に含まれる要素のうち upper 未満で最大のものを返す. 存在しない場合は -1 を返す.
- next_value(l, r, lower): 区間 [l,r) に含まれる要素のうち lower 以上で最小のものを返す. 存在しない場合は -1 を返す.
#include <bits/stdc++.h>

namespace LIB {

/**
 * @brief Succinct indexable dictionary: a fixed-length bit vector with rank queries.
 *
 * Bits are packed into 32-bit words (`bit`). After build(), `sum[i]` holds the
 * number of set bits in words [0, i), so rank() needs one table lookup plus
 * one popcount.
 **/
struct SuccinctIndexableDictionary {
    using ll = long long;

    ll length = 0;                    // number of addressable bits
    ll blocks = 0;                    // number of 32-bit words backing `bit`
    std::vector< unsigned > bit, sum; // packed bits / prefix popcounts per word

    SuccinctIndexableDictionary() = default;

    /// Creates a dictionary of `length` bits, all cleared.
    SuccinctIndexableDictionary(size_t length) : length(length), blocks((length + 31) >> 5) {
        bit.assign(blocks, 0U);
        sum.assign(blocks, 0U);
    }

    /// Returns the k-th bit.
    bool operator[](ll k) const { return bool((bit[k >> 5] >> (k & 31)) & 1U); }

    /// Number of set bits in [0, k). Only meaningful after build().
    ll rank(ll k) const {
        const unsigned below = bit[k >> 5] & ((1U << (k & 31)) - 1U);
        return ll(sum[k >> 5]) + bitcnt(below);
    }

    /// Number of bits equal to `val` in [0, k).
    ll rank(bool val, ll k) const { return val ? rank(k) : k - rank(k); }

    /// Sets the k-th bit. Call build() afterwards to refresh the rank table.
    void set(ll k) { bit[k >> 5] |= 1U << (k & 31); }

    /// Recomputes the prefix popcount table from the current bits.
    void build() {
        if (blocks == 0) return; // nothing to index; avoids touching sum[0] on an empty vector
        sum[0] = 0U;
        for (ll i = 1; i < blocks; i++) sum[i] = sum[i - 1] + unsigned(bitcnt(bit[i - 1]));
    }

private:
    /// Population count of one 32-bit word.
    static ll bitcnt(unsigned x) { return ll(std::bitset< 32 >(x).count()); }
};

/**
 * @brief Wavelet matrix over values that fit in MAXLOG bits.
 * @docs docs/wavelet-matrix.md
 *
 * Only the low MAXLOG bits of every stored value are indexed, and values are
 * ordered by that bit pattern from the most significant indexed bit down.
 * All ranges are half-open [l, r) and k is 0-indexed.
 */
template< typename T, long long MAXLOG >
struct WaveletMatrix {
    static_assert(MAXLOG > 0, "WaveletMatrix needs at least one bit level");

    using ll = long long;

    ll length = 0;
    SuccinctIndexableDictionary matrix[MAXLOG] = {}; // matrix[level]: bit `level` of each element, in that level's order
    ll mid[MAXLOG] = {};                             // mid[level]: number of elements whose bit `level` is 0

    WaveletMatrix() = default;

    /// Builds the matrix from `v` (taken by value because it is reordered level by level).
    WaveletMatrix(std::vector< T > v) : length(v.size()) {
        std::vector< T > l(length), r(length);
        for (ll level = MAXLOG - 1; level >= 0; level--) {
            // One extra bit so that rank(length) stays in bounds.
            matrix[level] = SuccinctIndexableDictionary(length + 1);
            ll left = 0, right = 0;
            for (ll i = 0; i < length; i++) {
                if ((v[i] >> level) & 1) {
                    matrix[level].set(i);
                    r[right++] = v[i];
                } else {
                    l[left++] = v[i];
                }
            }
            mid[level] = left;
            matrix[level].build();
            // Stable partition: zeros first, then ones, ready for the next level.
            v.swap(l);
            for (ll i = 0; i < right; i++) v[left + i] = r[i];
        }
    }

    /// Maps [l, r) at `level` to the range occupied, on the next level, by its elements whose bit is `f`.
    std::pair< ll, ll > succ(bool f, ll l, ll r, ll level) const {
        const ll offset = f ? mid[level] : 0;
        return { matrix[level].rank(f, l) + offset, matrix[level].rank(f, r) + offset };
    }

    /// v[k]
    T operator[](const ll& k) const { return access(k); }

    /// k-th (0-indexed) largest value in v[l, r).
    T kth_largest(ll l, ll r, ll k) const { return kth_smallest(l, r, r - l - k - 1); }

    /// Count of i with l <= i < r and lower <= v[i] < upper.
    ll range_freq(ll l, ll r, T lower, T upper) const {
        return range_freq(l, r, upper) - range_freq(l, r, lower);
    }

    /// v[k], reconstructed bit by bit from the top level down.
    T access(ll k) const {
        T ret = 0;
        for (ll level = MAXLOG - 1; level >= 0; level--) {
            const bool f = matrix[level][k];
            if (f) ret |= T(1) << level;
            k = matrix[level].rank(f, k) + (f ? mid[level] : 0);
        }
        return ret;
    }

    /// Count of i with 0 <= i < r and v[i] == x.
    ll rank(const T& x, ll r) const {
        ll l = 0;
        for (ll level = MAXLOG - 1; level >= 0; level--) {
            std::tie(l, r) = succ((x >> level) & 1, l, r, level);
        }
        return r - l;
    }

    /// k-th (0-indexed) smallest value in v[l, r). Requires 0 <= k < r - l.
    T kth_smallest(ll l, ll r, ll k) const {
        assert(0 <= k && k < r - l);
        T ret = 0;
        for (ll level = MAXLOG - 1; level >= 0; level--) {
            const ll cnt = matrix[level].rank(false, r) - matrix[level].rank(false, l);
            const bool f = cnt <= k; // the answer has a 1 here iff the zero side holds too few elements
            if (f) {
                ret |= T(1) << level;
                k -= cnt;
            }
            std::tie(l, r) = succ(f, l, r, level);
        }
        return ret;
    }

    /// Count of i with l <= i < r and v[i] < upper.
    ll range_freq(ll l, ll r, T upper) const {
        constexpr ll kDigits = std::numeric_limits< T >::digits;
        // Shift amount clamped so the expression below is well-formed for every instantiation.
        constexpr ll kShift = MAXLOG < kDigits ? MAXLOG : 0;
        if (MAXLOG <= kDigits) {
            // Indexed values lie in [0, 2^MAXLOG); a bound outside that window
            // must not be truncated to its low MAXLOG bits.
            if (upper <= T(0)) return 0;
            if (MAXLOG < kDigits && (upper >> kShift) != T(0)) return r - l;
        }
        ll ret = 0;
        for (ll level = MAXLOG - 1; level >= 0; level--) {
            const bool f = (upper >> level) & 1;
            // Elements with a 0 where `upper` has a 1 (same higher bits) are strictly smaller.
            if (f) ret += matrix[level].rank(false, r) - matrix[level].rank(false, l);
            std::tie(l, r) = succ(f, l, r, level);
        }
        return ret;
    }

    /// Largest v[i] with l <= i < r and v[i] < upper, or T(-1) if there is none.
    T prev_value(ll l, ll r, T upper) const {
        const ll cnt = range_freq(l, r, upper);
        return cnt == 0 ? T(-1) : kth_smallest(l, r, cnt - 1);
    }

    /// Smallest v[i] with l <= i < r and lower <= v[i], or T(-1) if there is none.
    T next_value(ll l, ll r, T lower) const {
        const ll cnt = range_freq(l, r, lower);
        return cnt == r - l ? T(-1) : kth_smallest(l, r, cnt);
    }
};

/**
 * @brief Wavelet matrix over coordinate-compressed values.
 *
 * Values are replaced by their index in the sorted, de-duplicated list `ys`,
 * so negative or very large values can be stored. MAXLOG must be large enough
 * to hold every index (i.e. the number of distinct values is at most 2^MAXLOG).
 * prev_value / next_value return T(-1) when no element qualifies.
 *
 * @note
 * Explanation: https://ei1333.github.io/library/structure/wavelet/wavelet-matrix.cpp.html
 * Example submission: https://judge.yosupo.jp/submission/72212
 *   CompressedWaveletMatrix wm(v);
 *   ll a, b, c; cin >> a >> b >> c; vl ans;
 *   ans.push_back(wm.kth_smallest(a, b, c));
 **/
template< typename T = long long, long long MAXLOG = 64ll >
struct CompressedWaveletMatrix {
    using ll = long long;

    WaveletMatrix< ll, MAXLOG > mat; // matrix over compressed indices
    std::vector< T > ys;             // sorted distinct original values

    CompressedWaveletMatrix(const std::vector< T >& v) : ys(v) {
        std::sort(ys.begin(), ys.end());
        ys.erase(std::unique(ys.begin(), ys.end()), ys.end());
        std::vector< ll > t(v.size());
        for (ll i = 0; i < ll(v.size()); i++) t[i] = get(v[i]);
        mat = WaveletMatrix< ll, MAXLOG >(std::move(t));
    }

    /// v[k]
    T operator[](const ll& k) const { return access(k); }

    /// v[k]
    T access(ll k) const { return ys[mat.access(k)]; }

    /// Index of the first entry of `ys` that is >= x (ys.size() if there is none).
    inline ll get(const T& x) const {
        return std::lower_bound(ys.begin(), ys.end(), x) - ys.begin();
    }

    /// k-th (0-indexed) smallest value in v[l, r).
    T kth_smallest(ll l, ll r, ll k) const { return ys[mat.kth_smallest(l, r, k)]; }

    /// k-th (0-indexed) largest value in v[l, r).
    T kth_largest(ll l, ll r, ll k) const { return ys[mat.kth_largest(l, r, k)]; }

    /// Count of i with l <= i < r and v[i] < upper.
    ll range_freq(ll l, ll r, T upper) const { return mat.range_freq(l, r, get(upper)); }

    /// Count of i with l <= i < r and lower <= v[i] < upper.
    ll range_freq(ll l, ll r, T lower, T upper) const {
        return mat.range_freq(l, r, get(lower), get(upper));
    }

    /// Count of i with 0 <= i < r and v[i] == x.
    ll rank(const T& x, ll r) const {
        const ll pos = get(x);
        if (pos == ll(ys.size()) || ys[pos] != x) return 0; // x never occurs
        return mat.rank(pos, r);
    }

    /// Largest v[i] with l <= i < r and v[i] < upper, or T(-1) if there is none.
    T prev_value(ll l, ll r, T upper) const {
        const ll ret = mat.prev_value(l, r, get(upper));
        return ret == -1 ? T(-1) : ys[ret];
    }

    /// Smallest v[i] with l <= i < r and lower <= v[i], or T(-1) if there is none.
    T next_value(ll l, ll r, T lower) const {
        const ll ret = mat.next_value(l, r, get(lower));
        return ret == -1 ? T(-1) : ys[ret];
    }
};

} // namespace LIB
実装例
https://judge.yosupo.jp/submission/72212
https://atcoder.jp/contests/abc234/submissions/28422004