Cross-Batch Memory for Embedding Learning
這篇是關於deep metric learning的論文,
在這個領域中,如何找到有information的sample是核心問題,
通常我們會使用hard sample mining來找,
目前的方法主要有兩種思路:
第一種思路是在mini-batch內給權重
第二種思路是建構一個更有效的mini-batch採樣方法
然而,這兩種mining方法還是有一個天花板,就是mini-batch的大小,
由於受限於gpu memory大小,每個iteration最多只能看到一個mini-batch的samples。
他們在三個image retrieval dataset上,
實驗發現隨著mini-batch變大,效果(Recall@1)大幅提升。
還有對於不同的pair-based loss也有這種這種現象。
那作者就想有沒有辦法利用之前batch的sample來產生pair,來達到相同加大batch的效果?
這篇最主要的貢獻在這裡,他不只用當前mini-batch來產pair,
而是把過去的 mini-batch 提取的feature也拿過來與當前 mini-batch 作比較,產生pair。
但這樣首先會遇到一個問題,隨著training過程,model提取出的特徵是會改變的,那有辦法比較嗎?
所以作者又做了一個實驗,將feature隨著model training的偏移量,稱之為特徵偏移(Feature Drift)
可以發現,只要大約過了 3K iterations,過去的iterations裡所提取過的特徵,
會逐漸變成為當前模型的一個有效近似。
所以,我們可以把這些feature給存下來用,所佔memory也就幾十MB。
這是他們提出的方法:Cross-Batch Memory(XBM)架構如下:
首先會先train K個epoch,讓feature穩定,然後再使用queue去存特徵,
把最舊的特徵踢掉,新的加入,然後使用整個queue的feature去建pair,之後算loss。
這樣一來就等於是用好幾倍的batchsize去train model
可以看到,XBM 帶來了顯著的效果提升,尤其是在Contrastive Loss上,
在三個數據集上,Recall@1都至少提升10個點。
上圖我們展示了在 SOP 上訓練時的計算資源消耗,
即便把整個訓練集(50K+)的特徵都加載到 XBM,不過只需要 0.2GB。
甚至比直接增大 batch訓練的效果還好。
和SOTA對比,我們在幾個常用的圖像檢索數據庫上,
效果均大幅提升。