當前位置:
首頁 > 科技 > 數據載入過慢?這裡有一份TensorFlow加速指南

數據載入過慢?這裡有一份TensorFlow加速指南

王小新 編譯自 Towards Data Science

量子位 出品 | 公眾號 QbitAI

機器學習演算法爛熟於心,網路結構順手拈來,但是如果數據集載入時耗費大量時間,那整個訓練時間就會大大增加。

這個問題可能困擾著很多使用大型數據集訓練的煉丹師們。最近,Francesco Zuppichini在medium上的一篇文章,從使用Dataset函數的三個步驟講起,介紹了相應的解決方法。

以下內容譯自他的文章。

看完這篇文章後,千萬不要再用默認的輸入函數feed-dict了。

本文以TensorFlow 1.5為標準函數庫。根據以往經驗,在TensorFlow中,feed-dict函數可能是最慢的一種數據載入方法,盡量少用。把數據輸入到模型的最佳方法是使用輸入流水線(input pipeline),來確保GPU無須等待新數據輸入。

幸好,TensorFlow有一個內置介面,叫做Dataset。這個介面是為了更容易地實現數據輸入,在1.3版本已被提出。這份教程將會介紹如何使用它來創建輸入流水線,高效率地將數據輸入到模型中。

本文還會解釋Dataset介面的基本原理,包括最常見的一些用法。

所有代碼可從這個網址獲取:

https://github.com/FrancescoSaverioZuppichini/Tensorflow-Dataset-Tutorial/blob/master/dataset_tutorial.ipynb

概述

使用Dataset介面,有以下三個步驟:

1. 導入數據,從某些數據創建一個數據集實例;

2. 創建迭代器iterator,即使用已有的數據集來創建一個迭代器實例,對數據集進行迭代;

3. 消耗數據,即使用所創建的迭代器,從數據集中取出元素輸入到模型。

導入數據

首先,我們需要把數據導入到數據集中,有以下幾種方式。


使用Numpy

這是常用的一個方法,把一個numpy數組輸入到tensorflow中:

我們也可以輸入多個numpy數組。典型示例就是我們將一些數據根據特徵和標籤分類。


使用Tensors

當然,我們可以用Tensor來初始化數據集:


使用Placeholder

當我們需要多次修改Dataset函數中的數據時,這個方法是最合適的,稍後會詳細介紹。


使用generator

我們也可以使用生成器generator來初始化Dataset,在處理長度不同的元素(如序列)時,這種方法很有用:

在這種情況下,你還需要指定輸入數據的類型和形狀,來得到合適的Tensor。

創建迭代器

上面已經介紹了如何創建一個數據集,但是如何拿出裡面的數據呢?這裡要使用迭代器Iterator,來遍歷整個數據集並取出數據的實際值,有以下四種類型。


One shot迭代器

這是最簡單的一種迭代器,利用上節的示例一:

接著,再調用get_next()函數來獲取下一個數據張量:

然後,運行el函數來得到輸出值:


可初始化迭代器

如果要構建一個動態數據集,在運行時要更改其中的源數據,則應該選擇佔位符placeholder來創建數據集,然後使用經典的feed-dict方法來初始化佔位符。這些可以用一個可初始化迭代器來完成,利用上節「使用Placeholder」部分的示例:

這裡調用了make_initializable_iterator函數。在這個sess範圍內,運行initializer函數來傳遞數據,這裡先以隨機數組為例。

到這裡,我們已經構建好訓練集和測試集:

接下來,讀入數據來訓練模型,並在測試數據集上進行評估,這可通過訓練後再次初始化迭代器來完成:

可重初始化迭代器

這個概念與上個類似,要在數據之間進行動態切換。但是,上面是將新數據輸入到同一個數據集中,這裡是在數據集之間切換。同樣地,我們要構建一個訓練數據集和一個測試數據集:

我們可以創建兩個數據集:

這裡是關鍵,要構建一個通用型Iterator:

然後初始化數據集:

跟上面操作一樣,得到下個元素:

現在,可以使用構建的Session來直接運行這兩個初始化操作,把這些程序組合起來:


可饋送迭代器

在我看來,這些方法可能效果不好,它們基本上不是在數據集之間切換,而是在迭代器之間切換。你可以分別用make_one_shot_iterator函數和make_initializable_iterator函數來創建兩個迭代器。

消耗數據

在前面例子中,我們使用過Session來輸出數據集中元素next的值:

為了將數據傳遞給模型,我們只要傳遞get_next函數生成的張量。

在下面代碼段中,有一個包含兩個numpy數組的數據集,這裡用了第一節的例子。請注意,我們要用.random.sample函數來包裝另一個numpy數組以滿足數據批量化的維度要求:

接著,和上面一樣,創建一個迭代器:

下面,構建一個簡單的神經網路模型:

我們直接用iter.get_next函數輸出的張量作為第一層的輸入和損失函數的標籤,整理後得到:

輸出為:

更多內容


批處理

通常來說,批處理數據是一件麻煩的事。但是可以用Dataset函數中的批處理方法batch(BATCH_SIZE),按照設定尺寸來自動批處理數據集,其中默認值為1。在下面示例中,批尺寸設置為4:

輸出:


Repeat操作

使用repeat函數,可指定數據集中的迭代次數。若沒有參數傳遞,它會一直循環。通常在持續循環後直接用一個標準循環來控制epoch大小。


Shuffle操作

我們可使用shuffle函數來打亂數據集,該函數默認在每個epoch打亂數據集。

打亂數據集,這個操作是非常重要的,可以減弱過擬合效應。

我們也可以設置參數buffer_size,這是shuffle函數的一個內置參數,下個元素將在這個緩衝區中統一選擇。下面舉例:

第一次輸出:

第二次輸出:

可以看出,數字被打亂了。當然,你也可以設置下參數seed看看效果。


Map操作

你還可以使用map方法將自定義函數應用到數據集的每個元素中。在下面示例中,我們把每個元素都乘二:

輸出:

結論

本文介紹的Dataset API給我們提供了一種快速且穩定的方法來創建最佳的輸入流水線,以更好地訓練、評估和測試網路模型。這篇文章介紹了這個API中的大部分常見操作。更多代碼可參見本文對應的jupyter-notebook。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 量子位 的精彩文章:

蘋果智能音箱HomePod,在「智商」測試中排名墊底
只有音頻沒指紋,能抓對人嗎?CMU音頻分析AI說沒問題

TAG:量子位 |