如何利用TensorFlow.js部署簡單的AI版「你畫我猜」圖像識別應用
選自Medium
作者:Zaid Alyafeai
機器之心編譯
參與:Geek AI、路
本文創建了一個簡單的工具來識別手繪圖像,並且輸出當前圖像的名稱。該應用無需安裝任何額外的插件,可直接在瀏覽器上運行。作者使用谷歌 Colab 來訓練模型,並使用 TensorFlow.js 將它部署到瀏覽器上。
代碼和 demo
demo 地址:https://zaidalyafeai.github.io/sketcher/
代碼地址:https://github.com/zaidalyafeai/zaidalyafeai.github.io/tree/master/sketcher
請通過以下鏈接在谷歌 Colab 上測試自己的 notebook:https://colab.research.google.com/github/zaidalyafeai/zaidalyafeai.github.io/blob/master/sketcher/Sketcher.ipynb
數據集
我們將使用卷積神經網路(CNN)來識別不同類型的手繪圖像。這個卷積神經網路將在 Quick Draw 數據集(https://github.com/googlecreativelab/quickdraw-dataset)上接受訓練。該數據集包含 345 個類別的大約 5 千萬張手繪圖像。
部分圖像類別
流程
我們將使用 Keras 框架在谷歌 Colab 免費提供的 GPU 上訓練模型,然後使用 TensorFlow.js 直接在瀏覽器上運行模型。我在 TensorFlow.js 上創建了一個教程(https://medium.com/tensorflow/a-gentle-introduction-to-tensorflow-js-dba2e5257702)。在繼續下面的工作之前,請務必先閱讀一下這個教程。下圖為該項目的處理流程:
流程
在 Colab 上進行訓練
谷歌 Colab 為我們提供了免費的 GPU 處理能力。你可以閱讀下面的教程(https://medium.com/deep-learning-turkey/google-colab-free-gpu-tutorial-e113627b9f5d)了解如何創建 notebook 和開始進行 GPU 編程。
導入
我們將使用以 TensorFlow 作為後端、Keras 作為前端的編程框架
載入數據
由於內存容量有限,我們不會使用所有類別的圖像進行訓練。我們僅使用數據集中的 100 個類別(https://raw.githubusercontent.com/zaidalyafeai/zaidalyafeai.github.io/master/sketcher/mini_classes.txt)。每個類別的數據可以在谷歌 Colab(https://console.cloud.google.com/storage/browser/quickdrawdataset/full/numpybitmap?pli=1)上以 NumPy 數組的形式獲得,數組的大小為 [N, 784],其中 N 為某類圖像的數量。我們首先下載這個數據集:
由於內存限制,我們在這裡將每類圖像僅僅載入 5000 張。我們還將留出其中的 20% 作為測試數據。
數據預處理
我們對數據進行預處理操作,為訓練模型做準備。該模型將使用規模為 [N, 28, 28, 1] 的批處理,並且輸出規模為 [N, 100] 的概率。
創建模型
我們將創建一個簡單的卷積神經網路。請注意,模型越簡單、參數越少越好。實際上,我們將把模型轉換到瀏覽器上然後再運行,並希望模型能在預測任務中快速運行。下面的模型包含 3 個卷積層和 2 個全連接層:
擬合、驗證及測試
在這之後我們對模型進行了 5 輪訓練,將訓練數據分成了 256 批輸入模型,並且分離出 10% 作為驗證集。
訓練結果如下圖所示:
測試準確率達到了 92.20% 的 top 5 準確率。
準備 WEB 格式的模型
在我們得到滿意的模型準確率後,我們將模型保存下來,以便進行下一步的轉換。
為轉換安裝 tensorflow.js:
接著我們對模型進行轉換:
這個步驟將創建一些權重文件和包含模型架構的 json 文件。
通過 zip 將模型進行壓縮,以便將其下載到本地機器上:
最後下載模型:
在瀏覽器上進行推斷
本節中,我們將展示如何載入模型並且進行推斷。假設我們有一個尺寸為 300*300 的畫布。在這裡,我們不會詳細介紹函數介面,而是將重點放在 TensorFlow.js 的部分。
載入模型
為了使用 TensorFlow.js,我們首先使用下面的腳本:
你的本地機器上需要有一台運行中的伺服器來託管權重文件。你可以在 GitHub 上創建一個 apache 伺服器或者託管網頁,就像我在我的項目中所做的那樣(https://github.com/zaidalyafeai/zaidalyafeai.github.io/tree/master/sketcher)。
接著,通過下面的代碼將模型載入到瀏覽器:
關鍵字 await 的意思是等待模型被瀏覽器載入。
預處理
在進行預測前,我們需要對數據進行預處理。首先從畫布中獲取圖像數據:
文章稍後將介紹 getMinBox()。dpi 變數被用於根據屏幕像素的密度對裁剪出的畫布進行拉伸。
我們將畫布當前的圖像數據轉化為一個張量,調整大小並進行歸一化處理:
我們使用 model.predict 進行預測,這將返回一個規模為「N, 100」的概率。
我們可以使用簡單的函數找到 top 5 概率。
提升準確率
請記住,我們的模型接受的輸入數據是規模為 [N, 28, 28, 1] 的張量。我們繪圖畫布的尺寸為 300*300,這可能是兩個手繪圖像的大小,或者用戶可以在上面繪製一個小圖像。最好只裁剪包含當前手繪圖像的方框。為了做到這一點,我們通過找到左上方和右下方的點來提取圍繞圖像的最小邊界框。
用手繪圖像進行測試
下圖顯示了一些第一次繪製的圖像以及準確率最高的類別。所有的手繪圖像都是我用滑鼠畫的,用筆繪製的話應該會得到更高的準確率。
本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。
------------------------------------------------
※弱監督學習下的商品識別:CVPR 2018細粒度識別挑戰賽獲勝方案簡介
※自動「腦補」3D環境!DeepMind最新Science論文生成查詢網路GQN
TAG:機器之心 |