一個提升圖像識別準確率的精妙技巧
作者:iphyer
原文://iphyer.github.io/blog/2018/04/30/matrix/
起因
今天和同學一切討論我們在做的項目,其中他特別指出了一段 Python 代碼的精妙之處。我當時沒能立刻理解,回來仔細思考了多次,終於想明白了這個問題,不由拍案叫絕。所以特地總結下這個問題給出自己的思考。
背景
像素坐標系
我們都知道,計算機圖像是由像素點組成的。簡單化的理解,忽略各種格式的區別,就是紅色,綠色,藍色三個通道的二維矩陣組合。一個通道是一個矩陣,這樣在屏幕上,同一個像素點,三種顏色的混合就組成了一幅彩色圖像。如下圖所示
人看到的是色彩斑瀾的圖像,計算機看到的其實是一個填充了數字的矩陣。
Bounding Box
在圖像識別的時候很重要的一個工作是通過 Bounding Box 把圖像的大致區域框出來。在這一步並不需要每一個像素都精確區分,而是大致框出物體的範圍。如下圖所示:
那麼問題來了,怎麼表示這個 Bounding Box?
需要畫出邊框上的每個點嗎? 顯然不需要。
那最少需要幾個?可以發現四個頂點是最特殊的,那四個都必須嗎?顯然也不需要。
最少我們只需要不是共線的兩個頂點就行了,也就是互為對角線的頂點都可以。
習慣上我們取左上的點和右下的點(從人的角度看,暫時不考慮像素坐標系的方向問題)。
圖像識別演算法
如上面的 Bounding Box 圖所示,其實圖像識別演算法就是對於任意一個輸入的圖像,考察是不是準確地預測出了 Bounding Box 的所在位置。當然為了讓演算法可以知道怎麼學習,我們往往會事先通過 Bounding Box 標記出圖像的位置,然後訓練神經網路去通過學習圖像的各種特徵預測可能出現的Bounding Box 位置。
簡單說,我們先標記一些 Bounding Box 然後我們的演算法通過學習訓練,最後實現對於一張未知圖像的 Bounding Box 的預測。
precision(準確率) 和 recall(召回率)
這裡就會有一個很自然的問題,我們怎麼知道預測的準確還是不準確呢?
通常在機器學習演算法中我們使用兩個指標來表示演算法的性能,precision(準確率) 和 recall(召回率)。簡單說,用疾病檢測舉例,如果在 10000 人的檢測樣本中已知有 500 個陽性的病人需要預測出來,現在你設計了一個演算法,預測了 400 個人是陽性,但是實際上這四百人裡面只有 300 人是真的陽性(預測對了),其中 100 個人是陰性(預測錯了)。所以你的準確率就是,基於你預測的 400 人,因為有 300 人是對的,所以準確率是 300 / 400 = 0.75。 換句話說準確率就是你預測的所有結果中有多少是對的。
召回率表示的是你預測的結果中對的部分到底覆蓋了多少目標用戶,可以看到我們的目標是預測出 500 個陽性病人,但是你預測的 400 人中只有 300 個是對的,所以你的召回率就是 300 / 500 = 0.6。簡單說,召回率就是覆蓋率,我們希望的好的演算法能夠在預測的時候覆蓋更多的目標人群。
極端情況,如果你只預測一個人,同時這個人還是真是陽性,可以看到 準確率是 1 / 1 = 100%, 但是你的召回率只有 1/ 500 = 0.002 非常低。
另外的極端情況就是,你說所有人都是陽性病人,那麼召回率就是 1。因為你把所有病人都包括了,你預測對的人是 500,同時你期望的目標人群是500, 500 / 500 = 1。換句話說,通過非常極端嚴格的篩查條件,寧可錯誤絕不放過,你成功實現了全覆蓋。但是你的準確率低得令人髮指,只有 500 / 10000 = 0.05。
所以在日常的使用中,我們往往是需要綜合兩個指標的。比較自然的指標有 F1 分數等。有興趣的讀者可以自行研究。
問題
問題的提出
我們集中到今天討論的問題上,
假設對於某個圖像,我的演算法提出了自己預測的 N 個 Bounding Box,同時知道該圖像的 M 個正確的 Bounding Box, 如何準備快速的計算 precision(準確率) 和 recall(召回率)?
下面我給出問題的預備代碼和畫圖代碼幫助大家理解這個問題:
具體的圖像如下;
同時我們做一個簡化,只考慮每個矩陣中心是不是在可以接受的誤差範圍內重合。相當於我們暫時只考慮位置不考慮大小。因為大小往往還會做後續的精細調節。大部分演算法首先預測位置,然後大小到了需要的時候再進一步的細化。
計算中心可以使用如下的函數
問題的思路
一個簡單的思路就是逐個比較。這樣的話你的演算法需要寫大量的循環非常費力。
所以,在實際的工作中,我們大量使用矩陣操作,避免循環。因為矩陣操作往往可以避免循環,同時如果你能夠使用 GPU (通常圖像識別都會在 GPU 上運行),矩陣操作本身是經過特別優化的,特別適合 GPU 運行,這可以提高速度。
但是怎麼操作?
可以看到我特地給出了,預測 Bounding Box 數量和真實 Bounding Box 數量不一致的情況,所以這個時候如果不小心,非常容易出現矩陣的維度不匹配的情況那就毀掉了所有的計算。
所以我們可以用上面的例子來幫助思考,首先求中心的坐標,這樣,原來的 N x 4 矩陣和 M x 4 矩陣就變成了 N x 2 矩陣和 M x 2 矩陣。
的
這兒不得不說下,一個不太提及的小技巧。
在 中 也可以作為矩陣一個新的維度的佔位符。用官方文檔的說法,這叫做
The newaxis object can be used in all slicing operations to create an axis of length one. :const: newaxis is an alias for "None", and "None" can be used in place of this with the same result.
也就是如果你希望你的矩陣拓展出一個新的維度,比如從一個長度為 3 的 vector 到 3 x 1 或者 1 x 3 你就可以這麼寫
輸出是
那麼二維呢,新增加一個維度,用上面的寫法,我們可以看到結果是:
輸出是
這樣在 的操作中,如果需要增加一個額外的維度,比如存儲兩個矩陣做差的結果,就可以很方便將結果存儲在新增加的這個維度,後面會更進一步解釋。
也可以參考這個stackoverflowIn numpy, what does selection by [:,None] do?
問題的解決代碼
這裡我直接先給出代碼。
下面一步步解釋下這個函數
(1) 和 就是求 Bounding Box 的中心,在我們的例子中,做完這步操作, 是 3 x 2 矩陣 而 是 2 x 2 矩陣。
其中 是
是
(2) 將 3 x 2 矩陣 擴展為 3 x 1 x 2 矩陣
(3) 是3 x 1 x 2矩陣 減去 2 x 2 矩陣, 在數學上這是沒有定義的,因為數學要求矩陣減法操作必須維度相同,但是在 中,其實是把3 x 1 x 2矩陣作為 3 個 1 x 2 矩陣分別先擴充為 2 x 2 矩陣 再和2 x 2 的 矩陣做減法。
也就是 儘力將不能做的減法操作以第一個維度(你也可以指定其他維度)做減法操作。
所以這步結束我們得到一個 3 x 2 x 2 矩陣,沿著第一個維度,其中每一個小的 2 x 2 矩陣都是預測的 Bounding Box 和真實的 Bounding Box 的中心差值。
具體的結果如下:
如果還是有一些疑問的話,請一行行調試如下的代碼,我簡單加了輸出信息,
(4) 這步就是把上面求出來的 3 x 2 x 2 矩陣 和 我們允許的誤差矩陣( )做比較,如果小於誤差計作 ,否則計作 。但是這樣結束之後我們得到一個3 x 2 x 2的布爾(Boolean)矩陣。
分別表示對於每一個預測的 Bounding Box ,它的 x 中心和 y 中心是不是在允許的誤差範圍內。我們知道如果都在範圍內,那麼就是預測對了,反之只要一個不對,那就是預測不對。
所以 其實就是沿著第三個軸(Python從0計數)對於每一個元素是不是都是 ,
那麼第三個軸是哪個軸?
其實在我們說 差值布爾矩陣是3 x 2 x 2矩陣 就說明了第三個元素是第二個 2 ,也就是對於布爾矩陣沿著每個小列做操作。正好就是對於每個x軸中心誤差和y軸中心誤差做邏輯 操作。通過這個操作之後,我們發現現在布爾矩陣變成,
現在是3 x 2的矩陣了,相當於沿著第三個軸塌縮了。
現在的結果就是,對於我們的每一個預測(一共 3 個預測)我們分別和真實值( 兩個真實值 )做對比,得到一個判斷結果分別是 還是 ,我們可以看到第二個預測 Bounding Box 和第一個真實 Bounding Box 是吻合的。結合我們的圖,會發現正好是右下角重合的兩個矩形。
(4) 主要需要理解 ,其實就是返回每一個維度的非零元素,在我們這個布爾矩陣中就是返回每個維度的 數目。
但是在這兒 和 的含義不太明確。所以下面我進行一次形式化的論證。
可以看到在得到中心坐標後,我們的問題是
N x 2 的預測矩陣和 M x 2 真實值矩陣的差是不是在允許的範圍內。
所以在擴展一個維度後,變成N x 1 x 2 的擴展預測矩陣和 M x 2 真實值矩陣做差。然後對於每一個預測 Bounding Box 的中心,一共是 N 個,在新擴展的維度上和 M x 2 真實值矩陣做差。所以每一個預測 Bounding Box 的中心 ( 1 x 2 ) 先擴展成 M x 2,就是把單一行複製為 M行,得到每行元素相同的一個 M x 2 矩陣, 再和 M x 2 真實值矩陣做差。
所以我們到的布爾矩陣是 N x M x 2 的矩陣,其中第三個軸的第一個布爾元素代表 X 方向是不是在誤差內,第二個布爾元素代表 Y 方向是不是在誤差範圍內,然後再沿著第三個軸做 操作,所以我們得到 N x M 的矩陣,其中每一行都是代表一個預測 Bounding Box 的中心 是不是和 M 個真實 Bounding Box 的中心重合。
所以 得到 中的 數目而 得到 中的 數目。所以結合上面 precision(準確率) 和 recall(召回率) 的定義,我們知道 就是 precision(準確率) 而 就是 recall(召回率)。
當然為了得到具體的數值,還需要做進一步求解,不過已經非常簡單了, 也就是最後的幾步。
總結
可以看到這個對於這樣一個簡單的問題,通過充分利用 內置函數的特性和矩陣操作的便利,特別是增加一個維度實現對於另一個矩陣的遍歷實現具體的機器學習precision(準確率) 和 recall(召回率) 計算。
這是特別值得學習的操作。
題圖:pexels,CC0 授權。
※推薦一個小而美的 Python 格式化工具
※十道 Python 面試問題陷阱
TAG:編程派 |