深度學習訓練數據不平衡問題,怎麼解決?
本文為雷鋒字幕組編譯的技術博客,原標題 Deep learning unbalanced training data ? Solve it like this,作者為 Shubrashankh Chatterjee 。
翻譯 | 葉青 整理 | MY
當我們解決任何機器學習問題時,我們面臨的最大問題之一是訓練數據不平衡。不平衡數據的問題在於學術界對於相同的定義、含義和可能的解決方案存在分歧。我們將嘗試用圖像分類問題來解開訓練數據中不平衡類別的奧秘。
不平衡類會有什麼問題?
在一個分類問題中,如果在所有你想要預測的類別里有一個或者多個類別的樣本量非常少,那你的數據也許就面臨不平衡類別的問題。
舉例
1.欺詐預測(欺詐的數量遠遠小於真實交易的數量)
2.自然災害預測(不好的事情遠遠小於好的事情)
3.在圖像分類中識別惡性腫瘤(訓練樣本中含有腫瘤的圖像遠比沒有腫瘤的圖像少)
為什麼這是個問題呢?
不平衡類別會造成問題有兩個主要原因:
1.對於不平衡類別,我們不能得到實時的最優結果,因為模型/演算法從來沒有充分地考察隱含類。
2.它對驗證和測試樣本的獲取造成了一個問題,因為在一些類觀測極少的情況下,很難在類中有代表性。
解決這個問題有哪些不同方法?
現在有三種主要建議的方法,它們各有利弊:
1.欠採樣- 隨機刪除觀測數量足夠多的類,使得兩個類別間的相對比例是顯著的。雖然這種方法使用起來非常簡單,但很有可能被我們刪除了的數據包含著預測類的重要信息。
2.過採樣- 對於不平衡的類別,我們使用拷貝現有樣本的方法隨機增加觀測數量。理想情況下這種方法給了我們足夠的樣本數,但過採樣可能導致過擬合訓練數據。
3.合成採樣( SMOTE )-該技術要求我們用合成方法得到不平衡類別的觀測,該技術與現有的使用最近鄰分類方法很類似。問題在於當一個類別的觀測數量極度稀少時該怎麼做。比如說,我們想用圖片分類問題確定一個稀有物種,但我們可能只有一幅這個稀有物種的圖片。
儘管每種方法都有各自的優點,但沒有什麼特定的啟發式方法告訴我們什麼時候使用哪種方法。我們現在將使用深度學習特定的圖像分類問題詳細研究這個問題。
圖像分類中的不平衡類
在本節中,我們將選取一個圖像分類問題,其中存在不平衡類問題,然後我們將使用一種簡單有效的技術來解決它。
問題- 我們在 kaggle 網站上選擇「座頭鯨識別挑戰」,我們期望解決不平衡類別的挑戰(理想情況下,所分類的鯨魚數量少於未分類的鯨類,並且也有少數罕見鯨類我們有的圖像數量更少。)
來自 kaggle :「在這場比賽中,你面臨著建立一個演算法來識別圖像中的鯨魚種類的挑戰。您將分析 Happy Whale 資料庫中的超過25,000張圖像,這些數據來自研究機構和公共貢獻者。 通過您的貢獻,將會幫助打開有關全球海洋哺乳動物種群動態豐富的理解領域。」
我們來看看數據
由於這是一個多標籤圖像分類問題,我想首先檢查數據在各個類別間的分布情況。
上面的圖表表明,在4251個訓練圖片中,有超過2000個類別中只有一張圖片。還有一些類中有2-5個圖片。現在,這是一個嚴重的不平衡類問題。我們不能指望用每個類別的一張圖片對深度學習模型進行訓練(雖然有些演算法可能正是用來做這個的,例如 one-shot 分類問題,但我們現在忽略先這一點)。這也會產生一個問題,即如何劃分訓練樣本和驗證樣本。理想情況下,您會希望每個類都在訓練和驗證樣本中有所體現。
我們現在應該做什麼?
我們特別考慮了兩個選項:
選項1- 對訓練樣本進行嚴格的數據增強(我們可以做到這一點,但因為我們只需要針對特定類的數據增強,這可能無法完全達到我們的目的)。因此,我選擇了看起來很簡單的選項2。
選項2- 類似於我上面提到的過採樣選項。我僅僅使用不同的圖像增強技術將不平衡類的圖像在訓練數據中複製了15次。這受到了傑里米·霍華德(Jeremy Howard )的啟發,我猜他在一次深度學習講座(fast.ai course 課程的第1部分)里提到過這一點。
在開始選項2之前,我們先看看訓練樣本中的一些圖像。
特別的是,這些圖像都是鯨魚的尾巴。因此,識別很可能與特定的圖片方向有關。
我也注意到在數據中有很多圖像是黑白圖片或只有R / B / G通道。
根據這些觀察結果,我決定編寫下面的代碼,對訓練樣本中不平衡類的圖像進行小幅改動並保存它們:
以上代碼塊對不平衡類(數量小於10)中的每個圖像都進行如下處理:
1.將每張圖片的 R、G、B 通道分別保存為增強副本
2.保存每張圖片非銳化的增強副本
3.保存每張圖片非銳化的增強副本
在上面的代碼中可以看到,我們在這個練習中嚴格使用 pillow (一個 python 圖像庫)。
現在在每個不平衡類中都至少有了10個樣本。我們繼續進行訓練。
圖像增強 -我們簡單考慮這個問題。我們只想確保我們的模型能夠獲得鯨魚尾的詳細視圖。為此,我們將變焦圖包含到圖像增強中。
學習速率探測器 -我們決定將學習率定為0.01,正如學習速率探測器所示。
我們用 Resnet50 模型進行了很少的迭代(先凍結模型,再解凍)。發現凍結的模型對於這個問題也非常有用,因為 imagenet 中有鯨魚尾圖像。
在測試數據上表現如何?
最終我們在 kaggle 排行榜上獲得了真相。我們的提出的解決方案在本次比賽中排名34,前五的平均精確度為0.41928 :)
結論
有時,最簡單的方法是最合理的(如果你沒有更多的數據,只需稍加變化地拷貝現有的數據,假裝對模型來說這一類別的大多數觀測與它們基本類似)。它們最有效並且可以更容易和直觀地完成工作。
為什麼你需要改進訓練數據,如何改進?
※距 NIPS 2018 還有小半年,會上的各種挑戰賽已經開始啦
※談談複雜多分類問題上的一些個人理解
TAG:AI研習社 |