iT邦幫忙

2025 iThome 鐵人賽

DAY 27
0
Software Development

學習C++必備!30 天演算法入門到進階 + CSES 與 Leetcode 實戰系列 第 27

Day 27:Rabin–Karp + 滾動雜湊(Rolling Hash)

  • 分享至 

  • xImage
  •  

一、學習目標

  • 了解多項式滾動雜湊(Polynomial Rolling Hash) 的定義、前綴雜湊與 O(1) 取子字串雜湊。
  • 熟悉單/雙雜湊、避免碰撞的實務選參(base、mod),與安全的模運算寫法。
  • 能用雜湊解:子字串相等判斷、回文檢查、重複子字串長度二分、多模式匹配(Rabin–Karp)。
  • 知道何時用 KMP(精確匹配、理論無碰撞)與何時用 Hash(多次查詢、快速等值比較、二分長度)。

二、觀念說明

滾動雜湊

通用模板

struct RollingHash {
    using ull = unsigned long long;
    static const uint64_t MOD = 1000000007ULL;   // 可換成 1e9+9
    static const uint64_t BASE = 131ULL;         // 常見:131, 911382323 等

    // 內部使用 1-index:s[1..n]
    int n;
    vector<uint64_t> H;  // 前綴雜湊 H[i]
    vector<uint64_t> P;  // 冪表 P[i] = BASE^i

    RollingHash() {}
    RollingHash(const string& s) { init(s); }

    void init(const string& s) {
        n = (int)s.size();
        H.assign(n + 1, 0);
        P.assign(n + 1, 1);
        for (int i = 1; i <= n; ++i) {
            uint64_t x = (unsigned char)s[i-1] + 1;// 避免 0
            P[i] = (P[i-1] * BASE) % MOD;   // 冪累乘
            H[i] = (H[i-1] * BASE + x) % MOD;  // 前綴遞推
        }
    }

    // 取得 s[l..r] 的雜湊(1-index, O(1))
    uint64_t get(int l, int r) const {
        if (l > r) return 0;
        uint64_t res = (H[r] + MOD - (H[l-1] * P[r-l+1]) % MOD) % MOD;
        return res;
    }
};

子字串是否相等(多次查詢)

// 回答多筆「(a,b,len) 兩段是否相等」的查詢
vector<string> substringEqualQueries(const string& s,
                                     const vector<tuple<int,int,int>>& qs) {
    RollingHash rh(s);
    vector<string> ans; ans.reserve(qs.size());
    for (auto [a,b,len] : qs) {
        auto h1 = rh.get(a, a+len-1);
        auto h2 = rh.get(b, b+len-1);
        ans.push_back(h1 == h2 ? "Yes" : "No");
    }
    return ans;
}

區間是否回文(正反各建一套)

// 檢查 s[l..r] 是否回文(1-index)
struct PalChecker {
    RollingHash fwd, rev; int n;
    PalChecker(const string& s) {
        n = (int)s.size();
        fwd.init(s);
        string t = s; reverse(t.begin(), t.end());
        rev.init(t);
    }
    bool isPal(int l, int r) {
        // 反串對應區間: [n-r+1, n-l+1]
        auto a = fwd.get(l, r);
        auto b = rev.get(n - r + 1, n - l + 1);
        return a == b;
    }
};

三、CSES實戰演練

題目1:Pattern Positions

原題:
https://cses.fi/problemset/task/2104

題意:
給長字串 T 與樣式 P,輸出 P 在 T 中所有起始位置(1-index),可重疊。

解題思路:

  • 先對 T 建雙前綴雜湊;對 P 也取雙雜湊。
  • 逐一檢查長度 |P| 的視窗 T[i..i+m-1] 的雙雜湊是否等於 P 的雜湊。
  • 雙模將碰撞機率壓到極低;必要可再加一次實體比對(通常不需)。
#include <bits/stdc++.h>
using namespace std;
using ull = unsigned long long;

struct RH {
    static const ull M1=1000000007ULL, M2=1000000009ULL;
    static const ull B1=131ULL, B2=137ULL;
    vector<ull> p1,p2,h1,h2; int n;
    RH(){} RH(const string& s){ init(s); }
    void init(const string& s){
        n=s.size(); p1.assign(n+1,1); p2.assign(n+1,1);
        h1.assign(n+1,0); h2.assign(n+1,0);
        for(int i=1;i<=n;i++){
            p1[i]=(p1[i-1]*B1)%M1; p2[i]=(p2[i-1]*B2)%M2;
            ull x=(unsigned char)s[i-1]+1;
            h1[i]=(h1[i-1]*B1+x)%M1; h2[i]=(h2[i-1]*B2+x)%M2;
        }
    }
    pair<ull,ull> get(int l,int r){ // 1-index inclusive
        ull x1=(h1[r]+M1-(h1[l-1]*p1[r-l+1])%M1)%M1;
        ull x2=(h2[r]+M2-(h2[l-1]*p2[r-l+1])%M2)%M2;
        return {x1,x2};
    }
};

int main(){
    ios::sync_with_stdio(false); cin.tie(nullptr);
    string T,P; if(!(cin>>T>>P)) return 0;
    int n=T.size(), m=P.size();
    RH ht(T), hp(P);
    auto pat=hp.get(1,m);
    vector<int> pos;
    for(int i=1;i+m-1<=n;i++)
        if(ht.get(i,i+m-1)==pat) pos.push_back(i);
    for(size_t i=0;i<pos.size();i++)
        cout<<pos[i]<<(i+1==pos.size()?'\n':' ');
    if(pos.empty()) cout<<"\n";
    return 0;
}

題目2:Substring Order I

原題:
https://cses.fi/problemset/task/2108

題意:
給字串 s 與整數 k,求 s 的所有不同子字串中,字典序第 k 小的是哪一個。

解題思路:

  • 用後綴陣列(SA)或簡化的後綴排序得到所有後綴的字典序順序。
  • 每個後綴 suffix[i] 能貢獻的「新子字串數」= len - sa[i] - LCP[i]。
  • 走訪 SA,累加新子字串數,當累加超過 k 時,答案在當前後綴的「LCP[i] + (k - sum_before) 長度」的前綴。
  • 為了子字串比較與擷取穩定,我們用雙雜湊實作 get(l,r);SA/LCP 可用常見 O(nlogn) 建法。
#include <bits/stdc++.h>
using namespace std;
using ull = unsigned long long;

struct RH {
    static const ull M1=1000000007ULL, M2=1000000009ULL;
    static const ull B1=131ULL, B2=137ULL;
    vector<ull> p1,p2,h1,h2; int n;
    RH(){} RH(const string& s){ init(s); }
    void init(const string& s){
        n=s.size(); p1.assign(n+1,1); p2.assign(n+1,1);
        h1.assign(n+1,0); h2.assign(n+1,0);
        for(int i=1;i<=n;i++){
            p1[i]=(p1[i-1]*B1)%M1; p2[i]=(p2[i-1]*B2)%M2;
            ull x=(unsigned char)s[i-1]+1;
            h1[i]=(h1[i-1]*B1+x)%M1; h2[i]=(h2[i-1]*B2+x)%M2;
        }
    }
    pair<ull,ull> get(int l,int r){
        ull x1=(h1[r]+M1-(h1[l-1]*p1[r-l+1])%M1)%M1;
        ull x2=(h2[r]+M2-(h2[l-1]*p2[r-l+1])%M2)%M2;
        return {x1,x2};
    }
};

vector<int> build_sa(const string& s){
    int n=s.size(),N=max(256,n)+1;
    vector<int> sa(n),rnk(n),tmp(n),cnt(N);
    for(int i=0;i<n;i++) sa[i]=i,rnk[i]=(unsigned char)s[i];
    for(int k=1;k<n;k<<=1){
        auto key2 = [&](int i){ return i+k<n?rnk[i+k]:-1; };
        iota(cnt.begin(),cnt.end(),0); fill(cnt.begin(),cnt.end(),0);
        // sort by second key (counting sort by rnk[i+k])
        int mx=max(256,n)+1;
        vector<int> sa2(n);
        for(int i=0;i<n;i++){ int v=key2(i)+1; if(v<0)v=0; cnt[v]++; }
        for(int i=1;i<mx;i++) cnt[i]+=cnt[i-1];
        for(int i=n-1;i>=0;i--){ int v=key2(sa[i])+1; if(v<0)v=0; sa2[--cnt[v]]=sa[i]; }
        // sort by first key (stable by rnk[i])
        fill(cnt.begin(),cnt.begin()+mx,0);
        for(int i=0;i<n;i++) cnt[rnk[i]+1]++;
        for(int i=1;i<mx;i++) cnt[i]+=cnt[i-1];
        for(int i=n-1;i>=0;i--) sa[--cnt[rnk[sa2[i]]+1]]=sa2[i];
        tmp[sa[0]]=0;
        for(int i=1;i<n;i++)
            tmp[sa[i]] = tmp[sa[i-1]] + (rnk[sa[i-1]]!=rnk[sa[i]] || key2(sa[i-1])!=key2(sa[i]));
        rnk.swap(tmp);
        if(rnk[sa[n-1]]==n-1) break;
    }
    return sa;
}

vector<int> build_lcp(const string& s, const vector<int>& sa){
    int n=s.size(); vector<int> rnk(n),lcp(n-1,0);
    for(int i=0;i<n;i++) rnk[sa[i]]=i;
    int h=0;
    for(int i=0;i<n;i++){
        int k=rnk[i];
        if(k==n-1){ h=0; continue; }
        int j=sa[k+1];
        while(i+h<n && j+h<n && s[i+h]==s[j+h]) h++;
        lcp[k]=h;
        if(h) h--;
    }
    return lcp;
}

int main(){
    ios::sync_with_stdio(false); cin.tie(nullptr);
    string s; long long k;
    if(!(cin>>s>>k)) return 0;
    int n=s.size();
    auto sa = build_sa(s);
    auto lcp = build_lcp(s, sa);
    RH rh(s);

    auto count_new = [&](int i){ // 後綴 i 的新子字串數
        int suf_len = n - sa[i];
        int prev_lcp = (i==0?0:lcp[i-1]);
        return (long long)suf_len - prev_lcp;
    };

    long long acc = 0;
    for(int i=0;i<n;i++){
        long long add = count_new(i);
        if(acc + add >= k){
            int prev_lcp = (i==0?0:lcp[i-1]);
            int need = prev_lcp + (int)(k - acc); // 需要的長度
            cout << s.substr(sa[i], need) << "\n";
            return 0;
        }
        acc += add;
    }
    cout << "\n";
    return 0;
}

四、Leetcode實戰演練

題目1:Maximum Length of Repeated Subarray

原題:
https://leetcode.com/problems/maximum-length-of-repeated-subarray/description/

題意:
給兩個整數陣列 A、B,求它們的最長共同連續子陣列長度。

解題思路:

  • 對長度 L 二分;把 A 的所有長度 L 子陣列雜湊丟到表裡,再掃 B 的長度 L 子陣列是否有相同雜湊(必要時實體比對)。
  • 雜湊時把元素轉成 uint64_t 累乘;也可將元素平移(加偏移)避免負數。
class Solution {
    using ull = unsigned long long;
    static const ull M1=1000000007ULL, M2=1000000009ULL;
    static const ull B1=911382323ULL, B2=972663749ULL; // 大基數更穩
public:
    int findLength(vector<int>& A, vector<int>& B) {
        int n=A.size(), m=B.size();
        auto ok = [&](int L)->bool{
            if(L==0) return true;
            vector<ull> p1(max(n,m)+1,1), p2(max(n,m)+1,1);
            vector<ull> h1a(n+1,0), h2a(n+1,0), h1b(m+1,0), h2b(m+1,0);
            for(int i=1;i<=max(n,m);i++){
                p1[i]=(p1[i-1]*B1)%M1; p2[i]=(p2[i-1]*B2)%M2;
            }
            auto norm = [](long long x)->ull{ return (ull)(x + (1LL<<31)); };
            for(int i=1;i<=n;i++){
                ull x=norm(A[i-1]);
                h1a[i]=(h1a[i-1]*B1 + x)%M1;
                h2a[i]=(h2a[i-1]*B2 + x)%M2;
            }
            for(int i=1;i<=m;i++){
                ull x=norm(B[i-1]);
                h1b[i]=(h1b[i-1]*B1 + x)%M1;
                h2b[i]=(h2b[i-1]*B2 + x)%M2;
            }
            auto getA=[&](int l,int r){
                ull x1=(h1a[r]+M1-(h1a[l-1]*p1[r-l+1])%M1)%M1;
                ull x2=(h2a[r]+M2-(h2a[l-1]*p2[r-l+1])%M2)%M2;
                return (x1<<32)^x2;
            };
            auto getB=[&](int l,int r){
                ull x1=(h1b[r]+M1-(h1b[l-1]*p1[r-l+1])%M1)%M1;
                ull x2=(h2b[r]+M2-(h2b[l-1]*p2[r-l+1])%M2)%M2;
                return (x1<<32)^x2;
            };
            unordered_map<unsigned long long, vector<int>> mp; mp.reserve(n*2);
            for(int i=1;i+L-1<=n;i++) mp[getA(i,i+L-1)].push_back(i);
            for(int j=1;j+L-1<=m;j++){
                auto key=getB(j,j+L-1);
                if(mp.find(key)!=mp.end()) return true; // 可再做實體比對保險
            }
            return false;
        };
        int lo=0, hi=min((int)A.size(), (int)B.size()), ans=0;
        while(lo<=hi){
            int mid=(lo+hi)/2;
            if(ok(mid)){ ans=mid; lo=mid+1; } else hi=mid-1;
        }
        return ans;
    }
};

題目2:Find Substring With Given Hash Value

原題:
https://leetcode.com/problems/find-substring-with-given-hash-value/

題意:
給字串 s(小寫英文字母),以及整數 power, modulo, k, hashValue。
定義子字串 s[i..i+k-1] 的雜湊(從右到左、a→1),找出最左邊滿足 hash(i) == hashValue 的長度 k 子字串。

解題思路:

  • 依題目的右到左定義,用滾動雜湊從右往左滑動長度 k 的視窗。
  • 維護 cur = hash(i+1) 轉到 hash(i) 的遞推:
  • 先把 cur 乘上 power,再加上新加入的右到左最低次方字元值,最後減去移出視窗的那個字元對應的 power^k 貢獻。
  • 每次比較 cur % modulo 是否等於 hashValue,記錄最左位置。
class Solution {
public:
    string subStrHash(string s, int power, int modulo, int k, int hashValue) {
        int n = (int)s.size();
        auto val = [&](char c){ return (long long)(c - 'a' + 1); };

        long long pk = 1;                  // power^k % modulo
        for (int i = 0; i < k; ++i) pk = (pk * power) % modulo;

        long long cur = 0;                 // rolling hash of window s[i..i+k-1] (right-to-left form)
        int best = -1;

        // 初始化最右側長度 k 視窗:i = n-k
        for (int j = n - 1; j >= n - k; --j) {
            cur = (cur * power + val(s[j])) % modulo;
        }
        if (cur == hashValue) best = n - k;

        // 從右往左滑動
        for (int i = n - k - 1; i >= 0; --i) {
            // 移入 s[i],移出 s[i+k]
            // new = (cur * power + val(s[i])) - val(s[i+k]) * power^k
            cur = (cur * power + val(s[i])) % modulo;
            long long out = (val(s[i + k]) * pk) % modulo;
            cur = (cur - out) % modulo;
            if (cur < 0) cur += modulo;

            if (cur == hashValue) best = i;   // 往左更新,保證取到最左
        }

        return s.substr(best, k);
    }
};

上一篇
Day 26:KMP 演算法(高效子字串搜尋)
下一篇
Day 28:樹的基本性質與遍歷
系列文
學習C++必備!30 天演算法入門到進階 + CSES 與 Leetcode 實戰28
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言