在Cloud ML Engine的TPU上從頭訓練ResNet
選自GoogleCloud
作者:Lak Lakshmanan
機器之心編譯
參與:Geek AI、王淑婷
以往的測試顯示,張量處理單元(TPU)是能夠極大加快深度學習模型訓練速度的存在。本文作者將演示如何使用谷歌雲提供的 TPU 在自己的數據集上訓練一個最先進的圖像分類模型。文中還包含了詳細的教程目錄和內容,心動的讀者不妨跟著一起動手試試?
張量處理單元(TPU)是能夠大大加快深度學習模型訓練速度的硬體加速器。在斯坦福大學進行的獨立測試中,在 TPU 上訓練的 ResNet-50 模型能夠在 ImageNet 數據集上以最快的速度(30 分鐘)達到預期的準確率。
在本文中,我將帶領讀者使用谷歌雲提供的 TPU 在自己的數據集上訓練一個最先進的圖像分類模型。並且:
無需自行編寫 TensorFlow 代碼(我已經完成了所有代碼。)
不需要安裝軟體或基礎環境(Cloud ML Engine 是無伺服器的)
你可以在雲端訓練模型,然後在任何地方部署該模型(使用 Kubeflow)
作者寫的代碼:https://github.com/tensorflow/tpu/tree/master/models/official/resnet
Cloud ML Engine:https://cloud.google.com/ml-engine/docs/tensorflow/technical-overview
Kubeflow:https://github.com/kubeflow/kubeflow
完整的代碼存放在 GitHub 的一個 notebook 中。讀者可以使用這個 notebook 或這個 codelab 中的代碼來跟進此教程。我已經在 Cloud Datalab 中測試了 notebook,並且在 Cloud Shell 中測試了 codelab。
notebook:https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/quests/tpu/flowers_resnet.ipynb
codelab:https://codelabs.developers.google.com/codelabs/tpu-resnet
Cloud Datalab:https://cloud.google.com/datalab
Cloud Shell:https://cloud.google.com/shell/
Cloud TPUv2(如上圖所示)可以加快最先進的深度學習模型的訓練
教程目錄
指向 JPEG 數據的 CSV 文件
啟用 Cloud TPU 服務賬號
複製 ResNet 代碼
[可選] 在本地嘗試數據預處理
運行數據預處理代碼
訓練模型
部署模型
用模型進行預測
1. 指向 JPEG 數據的 CSV 文件
開始之前,你需要一個裝滿圖像文件和三個逗號分隔值(CSV)文件的文件夾,這些文件提供關於圖像的元數據。
首先,你需要一個 CSV 文件,該文件包含你希望用於訓練的圖像及其標籤。CSV 文件的每一行可能如下所示:
你可以根據自己的喜好對圖片進行命名,但是它的路徑應該是實時的,並且可以在谷歌雲存儲上訪問。標籤字元串也可以是你喜歡的任何字元串,但其中不能包含逗號。數據中應該至少包含兩類圖像,並且訓練數據集應該包含足夠多的每個類別的示例。因為本文需要從頭開始做圖像分類工作,我建議每個類別至少有 1000 張圖像,總的數據集規模至少為 20,000 張圖像。如果你的圖像數量較少,可以參考遷移學習教程(它使用相同的數據格式)。
其次,你需要一個和上面一樣的 CSV,然後將其用來評估模型。我建議你將 90% 的數據用於訓練,而另外 10% 的數據用於評估。確保評估數據集包含每個類別 10% 的圖像。
最後,你需要一個包含所有唯一標籤的文件,每行一個標籤。例如:
文件中的順序非常重要。如果最終的模型預測結果為「2」,你就可以得知該圖像為玫瑰(第 0 類代表雛菊)。你可以從你用於訓練的 CSV 文件中得到類的列表:
在上面的代碼中,我僅僅從訓練 CSV 文件中提取出了第二個欄位,並且對其進行排序,在得到的輸出結果中尋找到這些值的唯一的集合。通過你最熟悉的過程創建下面三個文件:「train_set.csv」、「eval_set.csv」、「labels.txt」,將他們上傳到雲存儲中,然後你就做好訓練模型的準備工作了。
2. 複製 ResNet 代碼
讓我們從官方 TPU 樣本(https://medium.com/r/?url=https%3A%2F%2Fgithub.com%2Ftensorflow%2Ftpu)中複製 ResNet 代碼,並製作一個可提交的包。為了做到這一點,你需要從我的 GitHub 代碼倉庫(https://github.com/GoogleCloudPlatform/training-data-analyst)中複製並運行以下腳本:
上面的「1.8」指的是 TensorFlow 1.8。我推薦大家使用最新版本的 TensorFlow。
3. 啟用 Cloud TPU 服務賬號
你需要允許 TPU 服務賬號與 ML Engine(機器學習引擎)進行對話。可以使用以下腳本查詢服務賬號,並且提供訪問許可權:
4. [可選] 在本地嘗試進行數據預處理
為了確保我們包的創建工作奏效,你可以嘗試運行下面的流程將 JPEG 文件轉換為 TensorFlow 記錄:
在這裡,「/tmp/input.csv」是你輸入的訓練文件的一小部分。請檢查訓練文件和驗證文件是否已經被正確創建。
5. 運行預處理代碼
運行以下代碼將 JPEG 文件轉換為 Cloud Dataflow 中的 TFReocord。這將向許多機器分發轉換代碼,並且自動放縮它的規模:
自動放縮 TensorFlow 記錄的創建
如果你希望在更新的數據上重新訓練你的模型,只需要在新的數據上運行這整套流程,但是請確保將其寫入到一個新的輸出目錄中,以免覆蓋之前的輸出結果。
6. 訓練模型
只需將訓練任務提交到 Cloud ML Engine 上,讓結果指向你的 Dataflow 作業的輸出目錄:
(以上是代碼截圖)
代碼中加粗的行代表了你可能想進行調整的部分:
通過這一行,我們可以在啟動訓練作業之前刪除「OUTDIR」。這會讓訓練從頭重新開始。如果你有新的圖像需要訓練,並且只希望更新現有的模型,那麼不需要刪除輸出目錄。
在這裡,我們使用了 ResNet-18,它是最小的 ResNet 模型。你可以選擇 ResNet-18、34、50 等模型。(完整列表請參閱「resnet_main.py」:https://medium.com/r/?url=https%3A%2F%2Fgithub.com%2Ftensorflow%2Ftpu%2Fblob%2Fmaster%2Fmodels%2Fofficial%2Fresnet%2Fresnet_main.py)。隨著數據集規模的增大,這些數據可以支撐起越來越大的模型的訓練:較大的模型在較小的數據集上進行訓練存在過擬合的風險。因此隨著數據集大小的增加,你可以使用更大的模型。
張量處理單元(TPU)在批處理(batch)規模為 1024 左右時工作效果非常好。而我所擁有的數據集非常小,因此使用較小的批處理規模的原因。
「train_steps」變數控制著你計劃用於訓練的時間(多少輪迭代)。每次給模型輸入數量為「train_batch_size」的圖像。要想得到一個大致合理的值,你可以嘗試配置你的訓練會話(session),這樣模型至少能接收到每個圖像 10 次。在本文的例子中,我擁有 3,300 張圖像,「train_batch_size」為 128,因此,為了模型能接收到每張圖像 10 次,我需要(3300*10)/128 步或者大約 250 步。損失曲線(見下一節 TensorBoard 中的示意圖)在 250 步時並沒有停滯(收斂),所以我將該值增大到 1,000。
「steps_per_eval」變數控制了評估的頻率。進行模型評估的計算開銷是高昂的,所以你需要試著限制評估的次數。我將訓練步設為 1000,每 250 步進行一次評估,因此我將對模型進行 4 次評估。
你需要明確指定訓練圖像、評估圖像以及標籤的數量。我使用以下腳本來確定這些數字(通過改變文件名指向你的數據集):
當模型訓練完成後(這取決於訓練文件批處理規模「train_batch_size」以及訓練步「train_step」的數量),模型文件將被導出至谷歌雲存儲中。
你可以通過 TensorBoard 查看最終的模型的質量(令其指向輸出目錄):
沒有嚴重的過擬合現象——損失曲線和評估準確率大致相等
準確率確實太低了,只有 80%。如果使用更多的數據進行訓練將有助於準確率提升。
7. 部署模型
你現在可以將模型作為 web 服務部署到 Cloud ML Engine 上(或者你可以自行安裝 TensorFlow Serving,並且在其他地方運行模型):
8. 通過模型進行預測
想要使用該模型進行預測,你需要將一個通過 base-64 方式編碼的 JPEG 圖像文件的內容發送到 web 服務上。下面是創建必要的字典數據結構的 Python 代碼片段:
將代碼封裝到可以進行必要身份驗證和 HTTP 調用的模版中:
當我使用這張圖片調用該模型時,得到了預期結果(向日葵):
這是向日葵還是其它的花呢?
本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。
------------------------------------------------
※比Python快100倍,利用spaCy和Cython實現高速NLP項目
※劍橋大學:156頁PPT全景展示AI過去的12個月
TAG:機器之心 |