當前位置:
首頁 > 知識 > 如何在TensorFlow中高效使用數據集

如何在TensorFlow中高效使用數據集

選自TowardsDataScience

作者:Francesco Zuppichini

機器之心編譯

處理並使用數據集是深度學習任務非常重要的組成部分。在本文中,作者 Francesco Zuppichini 將教你使用 TensorFlow 的內建管道向模型傳遞數據的方法,從此遠離「feed-dict」。本文內容已更新至最新的 TensorFlow 1.5 版本。

相關代碼地址:https://github.com/FrancescoSaverioZuppichini/Tensorflow-Dataset-Tutorial/blob/master/dataset_tutorial.ipynb

經常使用神經網路框架的人都會知道,feed-dict 是向 TensorFlow 傳遞信息最慢的方式,應該盡量避免使用。向模型提供數據的正確方式是使用輸入管道,這樣才能保證 GPU 在工作時永遠無需等待新的數據。

幸運的是,TensorFlow 擁有一個名為 Dataset 的內建 API,它可以讓我們的工作更加簡單。在本教程中,我們將介紹搭建內建管道,讓數據高效傳遞給模型的方法。

本文將解釋 Dataset 的基本原理,包含大多數常用案例。

概述

使用 Dataset 需要遵循三個步驟:

載入數據:為數據創建一個數據集實例。

創建一個迭代器:通過使用創建的數據集構建一個迭代器來對數據集進行迭代。

使用數據:通過使用創建的迭代器,我們可以找到可傳輸給模型的數據集元素。

載入數據

我們首先需要一些可以放入數據集的數據。

從 Numpy 導入

這是一種常見情況:我們擁有一個 numpy 數組,想把它傳遞給 TensorFlow。

我們當然也可以傳遞多個 numpy 數組,一個典型的例子是:當我們已有被分配多個特徵和標籤的數據時……

從張量導入

當然,我們也可以從張量中初始化自己的數據集。

從佔位符導入

當我們希望動態地修改 Dataset 中的數據時,這就會很有用,稍後會有詳述。

從生成器導入

我們還可以從生成器中初始化 Dataset,這種方式在擁有不同長度的元素的數組時有意義(例如一個序列)。

在這種情況下,你還需要告訴 Dataset 數據的類型和形狀以創建正確的張量。

創建迭代器

我們已經學會創建數據集了,但如何從中獲取數據呢?我們必須使用迭代器(Iterator),它會幫助我們遍曆數據集中的內容並找到真值。有四種類型的迭代器。

One Shot 迭代器

這是最簡單的迭代器,使用第一個示例:

隨後你需要調用 get_next() 來獲取包含這些數據的張量

我們可以運行 el 來查看它們的值。

可初始化的迭代器

如果我們想要創建一個動態的數據集,在其中可以實時更改數據源,我們可以用佔位符創建一個數據集。隨後我們可以使用通常的 feed-dict 機制來初始化佔位符。這一過程可用「可初始化迭代器(initializable iterator)」來完成。使用上一節中的第三個例子:

這次我們調用 make_initializable_iterator。然後,我們在 sess 中運行 initializer 操作,以傳遞數據,這種情況下數據是隨機的 numpy 數組。

假設現在我們有了訓練數據集和測試數據集,那麼常見的代碼如下:

然後,我們訓練該模型,並在測試數據集上對其進行測試,測試可以通過訓練後再次初始化迭代器來完成。

可重新初始化的迭代器

這個概念和之前的類似,即在數據之間動態地轉換。但並不是將新數據饋送到相同的數據集,而是在數據集之間轉換。如前,我們需要一個訓練集和一個測試集。

我們可以創建兩個數據集:

接下來是要展示的技巧,即創建一個通用的迭代器:

以及兩個初始化運算:

和之前一樣,我們得到了下一個元素:

現在,我們可以直接使用會話運行這兩個初始化運算。總結起來我們得到:

可饋送的迭代器

老實說,我並不認為這個有什麼用。基本上,它是用迭代器之間的轉換取代了數據集之間的轉換,從而得到如一個來自 make_one_shot_iterator() 的迭代器,以及一個來自 make_initializable_iterator() 的迭代器。

使用數據

在前述例子中,我們利用會話輸出 Dataset 中下一個元素的值。

為了將數據傳遞給模型,我們只需要傳遞從 get_next() 生成的張量。在下面的代碼中,我們有一個包含了兩個 numpy 數組的 Dataset,這裡用了和第一節一樣的例子。注意,我們需要將.random.sample 封裝到另一個 numpy 數組,以增加一個維度,從而將數據進行分批。

然後,和往常一樣,我們創建一個迭代器:

創建一個模型,即一個簡單的神經網路:

我們直接使用來自 iter.get_next() 的張量作為第一層的輸入和損失函數的標籤。總結起來我們得到:

輸出:

一些有用的技巧

數據分批

通常數據分批是一件令人痛苦的事情,但通過 Dataset API,我們可以利用 batch(BATCH_SIZE) 方法自動地將數據集按設定的批量大小進行分批。默認批量大小為 1。在下面的示例代碼中,我們使用的批量大小為 4。

輸出:

repeat

使用.repeat(),我們可以指定數據集被迭代的次數。如果不傳輸任何參數,循環將永久進行。通常來說,永久運行循環和在標準循環中直接控制 epoch 的數量可以得到不錯的結果。

shuffle

我們可以利用 shuffle() 進行數據集 shuffle,默認是在每一個 epoch 中將數據集 shuffle 一次。記住:數據集 shuffle 是避免過擬合的重要方法。

我們還可以設置參數 buffer_size,下一個元素將從該固定大小的緩存中均勻地選取。例如:

第一次運行的輸出:

第二次運行的輸出:

這樣,數據集 shuffle 就完成了。你還可以設置 seed 參數。

MAP

你可以使用 map 方法對數據集中的所有成員應用定製化函數。下列示例中,我們把每個元素乘 2:

輸出:

其他資源

TensorFlow 數據集教程:https://www.tensorflow.org/programmers_guide/datasets

數據集文檔:https://www.tensorflow.org/api_docs/python/tf/data/Dataset

結論

該數據集 API 使我們快速、穩健地創建優化輸入流程來訓練、評估和測試我們的模型。本文中,我們了解了很多可以常見操作。

本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。

------------------------------------------------

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

阿里巴巴提出極限低比特神經網路,用於深度模型壓縮和加速
IBM Watson提出人機推理網路HuMaINs,結合人機兩者優勢

TAG:機器之心 |