數據載入過慢?這裡有一份TensorFlow加速指南
王小新 編譯自 Towards Data Science
量子位 出品 | 公眾號 QbitAI
機器學習演算法爛熟於心,網路結構順手拈來,但是如果數據集載入時耗費大量時間,那整個訓練時間就會大大增加。
這個問題可能困擾著很多使用大型數據集訓練的煉丹師們。最近,Francesco Zuppichini在medium上的一篇文章,從使用Dataset函數的三個步驟講起,介紹了相應的解決方法。
以下內容譯自他的文章。
看完這篇文章後,千萬不要再用默認的輸入函數feed-dict了。
本文以TensorFlow 1.5為標準函數庫。根據以往經驗,在TensorFlow中,feed-dict函數可能是最慢的一種數據載入方法,盡量少用。把數據輸入到模型的最佳方法是使用輸入流水線(input pipeline),來確保GPU無須等待新數據輸入。
幸好,TensorFlow有一個內置介面,叫做Dataset。這個介面是為了更容易地實現數據輸入,在1.3版本已被提出。這份教程將會介紹如何使用它來創建輸入流水線,高效率地將數據輸入到模型中。
本文還會解釋Dataset介面的基本原理,包括最常見的一些用法。
所有代碼可從這個網址獲取:
https://github.com/FrancescoSaverioZuppichini/Tensorflow-Dataset-Tutorial/blob/master/dataset_tutorial.ipynb
概述
使用Dataset介面,有以下三個步驟:
1. 導入數據,從某些數據創建一個數據集實例;
2. 創建迭代器iterator,即使用已有的數據集來創建一個迭代器實例,對數據集進行迭代;
3. 消耗數據,即使用所創建的迭代器,從數據集中取出元素輸入到模型。
導入數據
首先,我們需要把數據導入到數據集中,有以下幾種方式。
使用Numpy
這是常用的一個方法,把一個numpy數組輸入到tensorflow中:
我們也可以輸入多個numpy數組。典型示例就是我們將一些數據根據特徵和標籤分類。
使用Tensors
當然,我們可以用Tensor來初始化數據集:
使用Placeholder
當我們需要多次修改Dataset函數中的數據時,這個方法是最合適的,稍後會詳細介紹。
使用generator
我們也可以使用生成器generator來初始化Dataset,在處理長度不同的元素(如序列)時,這種方法很有用:
在這種情況下,你還需要指定輸入數據的類型和形狀,來得到合適的Tensor。
創建迭代器
上面已經介紹了如何創建一個數據集,但是如何拿出裡面的數據呢?這裡要使用迭代器Iterator,來遍歷整個數據集並取出數據的實際值,有以下四種類型。
One shot迭代器
這是最簡單的一種迭代器,利用上節的示例一:
接著,再調用get_next()函數來獲取下一個數據張量:
然後,運行el函數來得到輸出值:
可初始化迭代器
如果要構建一個動態數據集,在運行時要更改其中的源數據,則應該選擇佔位符placeholder來創建數據集,然後使用經典的feed-dict方法來初始化佔位符。這些可以用一個可初始化迭代器來完成,利用上節「使用Placeholder」部分的示例:
這裡調用了make_initializable_iterator函數。在這個sess範圍內,運行initializer函數來傳遞數據,這裡先以隨機數組為例。
到這裡,我們已經構建好訓練集和測試集:
接下來,讀入數據來訓練模型,並在測試數據集上進行評估,這可通過訓練後再次初始化迭代器來完成:
可重初始化迭代器
這個概念與上個類似,要在數據之間進行動態切換。但是,上面是將新數據輸入到同一個數據集中,這裡是在數據集之間切換。同樣地,我們要構建一個訓練數據集和一個測試數據集:
我們可以創建兩個數據集:
這裡是關鍵,要構建一個通用型Iterator:
然後初始化數據集:
跟上面操作一樣,得到下個元素:
現在,可以使用構建的Session來直接運行這兩個初始化操作,把這些程序組合起來:
可饋送迭代器
在我看來,這些方法可能效果不好,它們基本上不是在數據集之間切換,而是在迭代器之間切換。你可以分別用make_one_shot_iterator函數和make_initializable_iterator函數來創建兩個迭代器。
消耗數據
在前面例子中,我們使用過Session來輸出數據集中元素next的值:
為了將數據傳遞給模型,我們只要傳遞get_next函數生成的張量。
在下面代碼段中,有一個包含兩個numpy數組的數據集,這裡用了第一節的例子。請注意,我們要用.random.sample函數來包裝另一個numpy數組以滿足數據批量化的維度要求:
接著,和上面一樣,創建一個迭代器:
下面,構建一個簡單的神經網路模型:
我們直接用iter.get_next函數輸出的張量作為第一層的輸入和損失函數的標籤,整理後得到:
輸出為:
更多內容
批處理
通常來說,批處理數據是一件麻煩的事。但是可以用Dataset函數中的批處理方法batch(BATCH_SIZE),按照設定尺寸來自動批處理數據集,其中默認值為1。在下面示例中,批尺寸設置為4:
輸出:
Repeat操作
使用repeat函數,可指定數據集中的迭代次數。若沒有參數傳遞,它會一直循環。通常在持續循環後直接用一個標準循環來控制epoch大小。
Shuffle操作
我們可使用shuffle函數來打亂數據集,該函數默認在每個epoch打亂數據集。
打亂數據集,這個操作是非常重要的,可以減弱過擬合效應。
我們也可以設置參數buffer_size,這是shuffle函數的一個內置參數,下個元素將在這個緩衝區中統一選擇。下面舉例:
第一次輸出:
第二次輸出:
可以看出,數字被打亂了。當然,你也可以設置下參數seed看看效果。
Map操作
你還可以使用map方法將自定義函數應用到數據集的每個元素中。在下面示例中,我們把每個元素都乘二:
輸出:
結論
本文介紹的Dataset API給我們提供了一種快速且穩定的方法來創建最佳的輸入流水線,以更好地訓練、評估和測試網路模型。這篇文章介紹了這個API中的大部分常見操作。更多代碼可參見本文對應的jupyter-notebook。
![](https://pic.pimg.tw/zzuyanan/1488615166-1259157397.png)
![](https://pic.pimg.tw/zzuyanan/1482887990-2595557020.jpg)
※蘋果智能音箱HomePod,在「智商」測試中排名墊底
※只有音頻沒指紋,能抓對人嗎?CMU音頻分析AI說沒問題
TAG:量子位 |