非平衡數據集 focal loss 多類分類
本文為 AI 研習社編譯的技術博客,原標題 :
Multi-class classification with focal loss for imbalanced datasets
作者 |Chengwei Zhang
翻譯 | 汪鵬校對 | 斯蒂芬·二狗子
審核 | Pita 整理 | 立魚王
https://medium.com/swlh/multi-class-classification-with-focal-loss-for-imbalanced-datasets-c478700e65f5
焦點損失函數 Focal Loss(2017年何凱明大佬的論文)被提出用於密集物體檢測任務。它可以訓練高精度的密集物體探測器,哪怕前景和背景之間比例為1:1000(譯者註:facal loss 就是為了解決目標檢測中類別樣本比例嚴重失衡的問題)。本教程將向您展示如何在給定的高度不平衡的數據集的情況下,應用焦點損失函數來訓練一個多分類模型。
背景
讓我們首先了解類別不平衡數據集的一般的處理方法,然後再學習 focal loss 的解決方式。
在多分類問題中,類別平衡的數據集的目標標籤是均勻分布的。若某類目標的樣本相比其他類在數量上佔據極大優勢,則可以將該數據集視為不平衡的數據集。這種不平衡將導致兩個問題:
訓練效率低下,因為大多數樣本都是簡單的目標,這些樣本在訓練中提供給模型不太有用的信息;
簡單的樣本數量上的極大優勢會搞垮訓練,使模型性能退化。
一種常見的解決方案是執行某種形式的困難樣本挖掘,實現方式就是在訓練時選取困難樣本 或 使用更複雜的採樣,以及重新對樣本加權等方案。
對具體圖像分類問題,對數據增強技術方案變更,以便為樣本不足的類創建增強的數據。
焦點損失函數旨在通過降低內部加權(簡單樣本)來解決類別不平衡問題,這樣即使簡單樣本的數量很大,但它們對總損失的貢獻卻很小。也就是說,該函數側重於用困難樣本稀疏的數據集來訓練。
將 Focal Loss 應用於欺詐檢測任務
為了演示,我們將會使用 Kaggle上的欺詐檢測數據集 構建一個分類器,這個數據及具有極端的類不平衡問題,它包含總共6354407個正常樣本和8213個欺詐案例,兩者比例約為733:1。對這種高度不平衡的數據集的分類問題,若某模型簡單猜測所有輸入樣本為「正常」就可以達到733 /(733 1)= 99.86%的準確度,這顯然是不合理。因此,我們需要的是這個模型能夠正確檢測出欺詐案例。
為了證明focal loss 比傳統技術更有效,讓我們建立一個簡單地使用類別權重 class_weight訓練的基準模型,告訴模型「更多地關注」來自代表性不足的欺詐樣本。
基準模型
基準模型的準確率達到了99.87%,略好於通過採取「簡單路線」去猜測所有情況都為「正常」。
我們還繪製了混淆矩陣來展示模型在測試集上的分類性能。你可以看到總共有1140 480 = 1620 個樣本被錯誤分類。
混淆矩陣-基準模型
現在讓我們將focal loss應用於這個模型的訓練。你可以在下面看到如何在Keras框架下自定義焦點損失函數focal loss 。
焦點損失函數-模型
焦點損失函數focal loss 有兩個可調的參數。
焦點參數γ(gamma)平滑地調整簡單樣本被加權的速率。當γ= 0時, focal loss 效果與交叉熵函數相同,並且隨著 γ 增加,調製因子的影響同樣增加(γ = 2在實驗中表現的效果最好)。
α(alpha):平衡focal loss ,相對於非 α 平衡形式可以略微提高它的準確度。
現在讓我們把訓練好的模型與之前的模型進行比較性能。
Focal Loss 模型:
精確度:99.94%
總錯誤分類測試集樣本:766 23 = 789,將錯誤數減少了一半。
混淆矩陣-focal loss模型
結論及導讀
在這個快速教程中,我們為你的知識庫引入了一個新的工具來處理高度不平衡的數據集 — Focal Loss。並通過一個具體的例子展示了如何在Keras 的 API 中定義 focal loss進而改善你的分類模型。
你可以在我的GitHub上找到這篇文章的完整源代碼。
有關focal loss的詳細情況,可去查閱論文https://arxiv.org/abs/1708.02002。
※向頻域方向演進的卷積網路:OctConv用更低計算力做到更高準確率
※贏了世界冠軍不意外,和AI在DOTA中並肩作戰才讓人又糾結又興奮
TAG:AI研習社 |