當前位置:
首頁 > 知識 > 在Cloud ML Engine的TPU上從頭訓練ResNet

在Cloud ML Engine的TPU上從頭訓練ResNet

選自GoogleCloud

作者Lak Lakshmanan

機器之心編譯

參與:Geek AI、王淑婷

以往的測試顯示,張量處理單元(TPU)是能夠極大加快深度學習模型訓練速度的存在。本文作者將演示如何使用谷歌雲提供的 TPU 在自己的數據集上訓練一個最先進的圖像分類模型。文中還包含了詳細的教程目錄和內容,心動的讀者不妨跟著一起動手試試?

張量處理單元(TPU)是能夠大大加快深度學習模型訓練速度的硬體加速器。在斯坦福大學進行的獨立測試中,在 TPU 上訓練的 ResNet-50 模型能夠在 ImageNet 數據集上以最快的速度(30 分鐘)達到預期的準確率。

在本文中,我將帶領讀者使用谷歌雲提供的 TPU 在自己的數據集上訓練一個最先進的圖像分類模型。並且:

無需自行編寫 TensorFlow 代碼(我已經完成了所有代碼。)

不需要安裝軟體或基礎環境(Cloud ML Engine 是無伺服器的)

你可以在雲端訓練模型,然後在任何地方部署該模型(使用 Kubeflow)

作者寫的代碼:https://github.com/tensorflow/tpu/tree/master/models/official/resnet

Cloud ML Engine:https://cloud.google.com/ml-engine/docs/tensorflow/technical-overview

Kubeflow:https://github.com/kubeflow/kubeflow

完整的代碼存放在 GitHub 的一個 notebook 中。讀者可以使用這個 notebook 或這個 codelab 中的代碼來跟進此教程。我已經在 Cloud Datalab 中測試了 notebook,並且在 Cloud Shell 中測試了 codelab。

notebook:https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/quests/tpu/flowers_resnet.ipynb

codelab:https://codelabs.developers.google.com/codelabs/tpu-resnet

Cloud Datalab:https://cloud.google.com/datalab

Cloud Shell:https://cloud.google.com/shell/

Cloud TPUv2(如上圖所示)可以加快最先進的深度學習模型的訓練

教程目錄

指向 JPEG 數據的 CSV 文件

啟用 Cloud TPU 服務賬號

複製 ResNet 代碼

[可選] 在本地嘗試數據預處理

運行數據預處理代碼

訓練模型

部署模型

用模型進行預測

1. 指向 JPEG 數據的 CSV 文件

開始之前,你需要一個裝滿圖像文件和三個逗號分隔值(CSV)文件的文件夾,這些文件提供關於圖像的元數據。

首先,你需要一個 CSV 文件,該文件包含你希望用於訓練的圖像及其標籤。CSV 文件的每一行可能如下所示:

你可以根據自己的喜好對圖片進行命名,但是它的路徑應該是實時的,並且可以在谷歌雲存儲上訪問。標籤字元串也可以是你喜歡的任何字元串,但其中不能包含逗號。數據中應該至少包含兩類圖像,並且訓練數據集應該包含足夠多的每個類別的示例。因為本文需要從頭開始做圖像分類工作,我建議每個類別至少有 1000 張圖像,總的數據集規模至少為 20,000 張圖像。如果你的圖像數量較少,可以參考遷移學習教程(它使用相同的數據格式)。

其次,你需要一個和上面一樣的 CSV,然後將其用來評估模型。我建議你將 90% 的數據用於訓練,而另外 10% 的數據用於評估。確保評估數據集包含每個類別 10% 的圖像。

最後,你需要一個包含所有唯一標籤的文件,每行一個標籤。例如:

文件中的順序非常重要。如果最終的模型預測結果為「2」,你就可以得知該圖像為玫瑰(第 0 類代表雛菊)。你可以從你用於訓練的 CSV 文件中得到類的列表:

在上面的代碼中,我僅僅從訓練 CSV 文件中提取出了第二個欄位,並且對其進行排序,在得到的輸出結果中尋找到這些值的唯一的集合。通過你最熟悉的過程創建下面三個文件:「train_set.csv」、「eval_set.csv」、「labels.txt」,將他們上傳到雲存儲中,然後你就做好訓練模型的準備工作了。

2. 複製 ResNet 代碼

讓我們從官方 TPU 樣本(https://medium.com/r/?url=https%3A%2F%2Fgithub.com%2Ftensorflow%2Ftpu)中複製 ResNet 代碼,並製作一個可提交的包。為了做到這一點,你需要從我的 GitHub 代碼倉庫(https://github.com/GoogleCloudPlatform/training-data-analyst)中複製並運行以下腳本:

上面的「1.8」指的是 TensorFlow 1.8。我推薦大家使用最新版本的 TensorFlow。

3. 啟用 Cloud TPU 服務賬號

你需要允許 TPU 服務賬號與 ML Engine(機器學習引擎)進行對話。可以使用以下腳本查詢服務賬號,並且提供訪問許可權:

4. [可選] 在本地嘗試進行數據預處理

為了確保我們包的創建工作奏效,你可以嘗試運行下面的流程將 JPEG 文件轉換為 TensorFlow 記錄:

在這裡,「/tmp/input.csv」是你輸入的訓練文件的一小部分。請檢查訓練文件和驗證文件是否已經被正確創建。

5. 運行預處理代碼

運行以下代碼將 JPEG 文件轉換為 Cloud Dataflow 中的 TFReocord。這將向許多機器分發轉換代碼,並且自動放縮它的規模:

自動放縮 TensorFlow 記錄的創建

如果你希望在更新的數據上重新訓練你的模型,只需要在新的數據上運行這整套流程,但是請確保將其寫入到一個新的輸出目錄中,以免覆蓋之前的輸出結果。

6. 訓練模型

只需將訓練任務提交到 Cloud ML Engine 上,讓結果指向你的 Dataflow 作業的輸出目錄:

(以上是代碼截圖)

代碼中加粗的行代表了你可能想進行調整的部分:

通過這一行,我們可以在啟動訓練作業之前刪除「OUTDIR」。這會讓訓練從頭重新開始。如果你有新的圖像需要訓練,並且只希望更新現有的模型,那麼不需要刪除輸出目錄。

在這裡,我們使用了 ResNet-18,它是最小的 ResNet 模型。你可以選擇 ResNet-18、34、50 等模型。(完整列表請參閱「resnet_main.py」:https://medium.com/r/?url=https%3A%2F%2Fgithub.com%2Ftensorflow%2Ftpu%2Fblob%2Fmaster%2Fmodels%2Fofficial%2Fresnet%2Fresnet_main.py)。隨著數據集規模的增大,這些數據可以支撐起越來越大的模型的訓練:較大的模型在較小的數據集上進行訓練存在過擬合的風險。因此隨著數據集大小的增加,你可以使用更大的模型。

張量處理單元(TPU)在批處理(batch)規模為 1024 左右時工作效果非常好。而我所擁有的數據集非常小,因此使用較小的批處理規模的原因。

「train_steps」變數控制著你計劃用於訓練的時間(多少輪迭代)。每次給模型輸入數量為「train_batch_size」的圖像。要想得到一個大致合理的值,你可以嘗試配置你的訓練會話(session),這樣模型至少能接收到每個圖像 10 次。在本文的例子中,我擁有 3,300 張圖像,「train_batch_size」為 128,因此,為了模型能接收到每張圖像 10 次,我需要(3300*10)/128 步或者大約 250 步。損失曲線(見下一節 TensorBoard 中的示意圖)在 250 步時並沒有停滯(收斂),所以我將該值增大到 1,000。

「steps_per_eval」變數控制了評估的頻率。進行模型評估的計算開銷是高昂的,所以你需要試著限制評估的次數。我將訓練步設為 1000,每 250 步進行一次評估,因此我將對模型進行 4 次評估。

你需要明確指定訓練圖像、評估圖像以及標籤的數量。我使用以下腳本來確定這些數字(通過改變文件名指向你的數據集):

當模型訓練完成後(這取決於訓練文件批處理規模「train_batch_size」以及訓練步「train_step」的數量),模型文件將被導出至谷歌雲存儲中。

你可以通過 TensorBoard 查看最終的模型的質量(令其指向輸出目錄):

沒有嚴重的過擬合現象——損失曲線和評估準確率大致相等

準確率確實太低了,只有 80%。如果使用更多的數據進行訓練將有助於準確率提升。

7. 部署模型

你現在可以將模型作為 web 服務部署到 Cloud ML Engine 上(或者你可以自行安裝 TensorFlow Serving,並且在其他地方運行模型):

8. 通過模型進行預測

想要使用該模型進行預測,你需要將一個通過 base-64 方式編碼的 JPEG 圖像文件的內容發送到 web 服務上。下面是創建必要的字典數據結構的 Python 代碼片段:

將代碼封裝到可以進行必要身份驗證和 HTTP 調用的模版中:

當我使用這張圖片調用該模型時,得到了預期結果(向日葵):

這是向日葵還是其它的花呢?

本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。

------------------------------------------------

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

比Python快100倍,利用spaCy和Cython實現高速NLP項目
劍橋大學:156頁PPT全景展示AI過去的12個月

TAG:機器之心 |