iT邦幫忙

2

資結經典題目: 用heap找前k個最小的數對

c++

參考題目: LeetCode 373. Find K Pairs with Smallest Sums
這一題是說有 nums1, nums2 兩個排序好的陣列,
考慮所有的數對(u,v),
u是nums1的數字,
v是nums2的數字,
數對的總和定義為u+v,
求前k個總和最小的數對。

範例測資

Input: nums1 = [1,7,11], nums2 = [2,4,6], k = 3
Output: [[1,2],[1,4],[1,6]] 
說明: 所有數對由小排到大順序為: [1,2],[1,4],[1,6],[7,2],[7,4],[11,2],[7,6],[11,4],[11,6]

小馬覺得這是資料結構運用「heap」的好題目,
要解開此題不難,
但要有效率的解開可是一大難題

假設nums1有n個元素,nums2有m個元素,k可能比nm相乘還要小很多

方法一: 暴力解,時間O(nmlog(nm))

若nums1有n個元素,nums2有m個元素,
那麼全部會有nm個數對,
直接將nm個數對進行排序,
需要O(nmlog(nm))的時間,
再將排序的結果取前k個出來即可

但是這樣便有個問題,
假設k遠小於n, m的情形,
將所有數對拿來排序是相當耗時的
(例如n=10000, m=10000, k=3)
我們需要更有效率的方法

方法二: 找適當的候選人放入heap,時間O(klog(k))

這邊我們畫一個表格會比較清楚哪些數對是可能的前k小候選人
nums1 = [1,7,11], nums2 = [2,4,6]為例,

https://ithelp.ithome.com.tw/upload/images/20200407/20117114wJmHmVqtqr.png

首先思考一下,第一小的數對是誰呢?
很簡單,數對總和最小,那當然是取nums1最小的數字與nums2最小的數字,
所以最小的數對一定是[1,2]

那第二小的呢?
有沒有可能是[7,4]呢?
不可能,因為至少有[1,4]和[7,2]兩個數對比它小

事實上,若[nums1[i],nums2[j]]是第k小的數對,
那麼第k+1小的數對候選人便會添加[nums1[i+1],nums2[j]]、[nums1[i],nums2[j+1]]兩個人,
即表格上[nums1[i],nums2[j]]右方一格及下方一格的位置,
右下角方向的其它格子都不可能是候選人

譬如說我要依序找nums1 = [1,7,11], nums2 = [2,4,6]前四小的數對

step1:
前1小數對: [1,2]
第2小候選人: [1,4], [7,2]

step2:
前2小數對: [1,2], [1,4]
第3小候選人: [7,2], [1,6], [7,4]

step3:
前3小數對: [1,2], [1,4], [1,6]
第4小候選人: [7,2], [7,4], [7,6]

step4:
前4小數對: [1,2], [1,4], [1,6], [7,2]
第5小候選人: [7,4], [7,6], [11, 2] (注意[7,4]添加過了,不要重複加入)

每次從候選人中找出最小的數字,
再增加至多二個次小的候選人進入heap,
像這種需要反覆「添加」、「取最小值」的動作,
便很適合使用min-heap,
因為min-heap「添加」、「取最小值」的動作只需要O(log(k))的時間

完整範例程式碼(用自己寫的heap結構所以程式很長)

#include <iostream>
#include <vector>
#include <climits>
using namespace std;

class Node
{
public:
    Node() = default;
    Node(int x, int y, vector<int> d): index_x(x), index_y(y), data(d){};
    int index_x, index_y;
    vector<int> data;
    
    //因為陣列元素可能重複,故數對和相同時,定義x座標小的較小避免順序錯亂
    bool operator<(const Node& other){
        if(data[0]+data[1] < other.data[0]+other.data[1])
            return true;
        if(data[0]+data[1] > other.data[0]+other.data[1])
            return false;
        return index_x<other.index_x;
    }
    bool operator>(const Node& other){
        if(data[0]+data[1] > other.data[0]+other.data[1])
            return true;
        if(data[0]+data[1] < other.data[0]+other.data[1])
            return false;
        return index_x>other.index_x;
    }
};

// 0-index base Heap
template <typename T> class Heap
{
private:
    void heapify();
public:
    Heap() = default;
    Heap(vector<T> vec):arr(vec){heapify();};
    vector<T> arr;
    void insert(T node);
    T* get_min();
    void extract_min();
};

//調整heap使它滿足parent都比child小的規則(總時間: O(n))
template <typename T>
void Heap<T>::heapify()
{
    //這邊與extract-min的做法一樣,從下往上,將有小孩的節點往下交換
    for (int i = (arr.size()-1)/2; i>=0; i--) {
        int idx=i;
        //若節點是有小孩的就繼續迴圈
        while(idx*2+1 <arr.size()){
            int small_chlid_idx  = idx*2+1; //判斷小孩節點比較小的編號
            if (small_chlid_idx + 1 < arr.size() && arr[small_chlid_idx + 1] < arr[small_chlid_idx])
                small_chlid_idx++;
            if(arr[idx] > arr[small_chlid_idx])
                swap(arr[idx], arr[small_chlid_idx]);
            idx = small_chlid_idx;
        }
    }
}

template <typename T>
void Heap<T>::insert(T data)
{
    arr.push_back(data);
    int idx = arr.size()-1;
    while(idx>0 && arr[idx]<arr[(idx-1)/2]){
        swap(arr[idx], arr[(idx-1)/2]);
        idx = (idx-1)/2;
    }
}

template <typename T>
T* Heap<T>::get_min()
{
    return !arr.empty()?&arr[0]:NULL;
}

//拿走最小的元素,若heap為空不做事
template <typename T>
void Heap<T>::extract_min()
{
    if(arr.empty()){
        return;
    }
    T small = arr[0];
    swap(arr[0], arr[arr.size()-1]);
    arr.pop_back(); //清掉最後一個元素(即最小的那個)
    int idx = 0;
    //若節點是有小孩的就繼續迴圈
    while(idx*2+1 <arr.size()){
        int small_chlid_idx  = idx*2+1; //判斷小孩節點比較小的編號
        if (small_chlid_idx + 1 < arr.size() && arr[small_chlid_idx + 1] < arr[small_chlid_idx])
            small_chlid_idx++;
        if(arr[idx] > arr[small_chlid_idx])
            swap(arr[idx], arr[small_chlid_idx]);
        idx = small_chlid_idx;
    }
}

//函數功能: nums1, nums2 是兩個排序好的陣列(數字可重複),回傳前k個總和最小的數對
vector<vector<int>> kSmallestPairs(vector<int>& nums1, vector<int>& nums2, int k) {
    
    if(nums1.empty()||nums2.empty())
        return vector<vector<int>>{}; //若其中一個陣列為空直接回傳
    if(nums1.size()*nums2.size()<INT_MAX) //避免nums1.size()*nums2.size()相乘溢位
        k = std::min(k, (int)(nums1.size()*nums2.size())); //k有可能比陣列大小更大,此時回傳陣列所有數對即可
    
    vector<int> boudary(nums2.size(), -1); //每一列使用的右邊界
    boudary[0] = 0;
    vector<vector<int>> result = {{nums1[0],nums2[0]}};
    Node last_use(0,0, {nums1[0],nums2[0]}); //最後一個用到的Node
    Heap<Node> h;
    
    for (int i = 0; i < k-1; i++) {
        int one = last_use.index_x, two = last_use.index_y;
        vector<vector<int>> candidates = {{one+1, two}, {one, two+1}};
        for(auto c: candidates){
            // 要檢查候選人是否超出邊界或重複添加
            if(c[0]<nums1.size() && c[1] < nums2.size() && c[0]> boudary[c[1]]){
                h.insert(Node(c[0],c[1], {nums1[c[0]], nums2[c[1]]}));
                boudary[c[1]] = c[0];
            }
        }
        last_use = *h.get_min();
        result.push_back(last_use.data);
        h.extract_min();
    }
    return result;
}

int main()
{
    vector<int> vec1={1,3,3,5};
    vector<int> vec2={2,4,6};
    int k = 13;

    vector<vector<int>> result = kSmallestPairs(vec1,vec2,k);
    for (int i = 0; i < result.size();i++) {
        std::cout << "["<<result[i][0]<< ", "<<result[i][1]<<"]";
        std::cout << (i < result.size()-1? ", ":"");
    }
    std::cout << std::endl;
    return 0;
}

尚未有邦友留言

立即登入留言