當前位置:
首頁 > 最新 > 如何使用TensorFlow中的Dataset API

如何使用TensorFlow中的Dataset API

翻譯 | AI科技大本營

參與 | zzq

審校 | reason_W

本文已更新至TensorFlow1.5版本

我們知道,在TensorFlow中可以使用feed-dict的方式輸入數據信息,但是這種方法的速度是最慢的,在實際應用中應該盡量避免這種方法。而使用輸入管道就可以保證GPU在工作時無需等待新的數據輸入,這才是正確的方法。

幸運的是,TensorFlow提供了一種內置的API——Dataset,使得我們可以很容易地就利用輸入管道的方式輸入數據。在這篇教程中,我們將介紹如何創建和使用輸入管道以及如何高效地向模型輸入數據。

這篇文章將解釋DatasetAPI的基本工作機制,並給出了幾種最常用的例子。

你可以通過下面的網站地址下載文章中的代碼:

https://github.com/FrancescoSaverioZuppichini/Tensorflow-Dataset-Tutorial/blob/master/dataset_tutorial.ipynb

▌概述

使用Dataset的三個步驟:

1. 載入數據:為數據創建一個Dataset實例

2. 創建一個迭代器:使用創建的數據集來構造一個Iterator實例以遍曆數據集

3. 使用數據:使用創建的迭代器,我們可以從數據集中獲取數據元素,從而輸入到模型中去。

▌載入數據

首先,我們需要將一些數據放到數據集中。

從numpy載入

這是最常見的情況,假設我們有一個numpy數組,我們想將它傳遞給TensorFlow

我們也可以傳遞多個numpy數組,最典型的例子是當數據被劃分為特徵和標籤的時候:

從tensors中載入

我們當然也可以用一些張量初始化數據集

從placeholder中載入

如果我們想動態地改變Dataset中的數據,使用這種方式是很有用的。

從generator載入

我們也可以從generator中初始化一個Dataset。當一個數組中元素長度不相同時,使用這種方式處理是很有效的。(例如一個序列)

在這種情況下,你還需要指定數據的類型和大小以創建正確的tensor

▌創建一個迭代器

我們已經知道了如何創建數據集,但是如何從中獲取數據呢?我們需要使用一個Iterator遍曆數據集並重新得到數據真實值。有四種形式的迭代器。

One shot Iterator

這是最簡單的迭代器,下面給出第一個例子:

接著你需要調用get_next()來獲得包含數據的張量

我們可以運行 el 來查看它們的值。

可初始化的迭代器

如果我們想建立一個可以在運行時改變數據源的動態數據集,我們可以用placeholder 創建一個數據集。接著用常見的feed-dict機制初始化這個placeholder。這些工作可以通過使用一個可初始化的迭代器完成。使用上一節的第三個例子

這次,我們調用make_initializable_iterator。接著我們在 sess 中運行 initializer 操作,以傳遞數據,這種情況下數據是隨機的 numpy 數組。

假設我們有了訓練集和測試集,如下代碼所示

接著,我們訓練該模型,並在測試數據集上對其進行測試,這可以通過訓練後對迭代器再次進行初始化來完成。

可重新初始化的迭代器

這個概念和之前的相似,我們想在數據間動態切換。但是我們是轉換數據集而不是把新數據送到相同的數據集。和之前一樣,我們需要一個訓練集和一個測試集

接下來創建兩個Dataset

現在我們要用到一個小技巧,即創建一個通用的Iterator

接著創建兩個初始化運算

和之前一樣,我們得到下一個元素

現在,我們可以直接使用session運行兩個初始化運算。把上面這些綜合起來我們可以得到:

Feedable迭代器

老實說,我並不認為這種迭代器有用。這種方式是在迭代器之間轉換而不是在數據集間轉換,比如在來自make_one_shot_iterator()的一個迭代器和來自make_initializable_iterator()的一個迭代器之間進行轉換。

▌使用數據

在之前的例子中,我們使用session來列印Dataset中next元素的值

現在為了向模型傳遞數據,我們只需要傳遞get_next()產生的張量。

在下面的代碼中,我們有一個包含兩個numpy數組的Dataset,這裡用到了和第一節一樣的例子。注意到我們需要將.random.sample封裝到另外一個numpy數組中,因此會增加一個維度以用於數據batch。

接下來和平時一樣,我們創建一個迭代器

建立一個簡單的神經網路模型

我們直接使用來自iter.get_next()的張量作為神經網路第一層的輸入和損失函數的標籤。將上面的綜合起來可以得到:

輸出:

▌有用的技巧

batch

通常情況下,batch是一件麻煩的事情,但是通過Dataset API我們可以使用batch(BATCH_SIZE)方法自動地將數據按照指定的大小batch,默認值是1。在接下來的例子中,我們使用的batch大小為4。

輸出:

Repeat

使用.repeat()我們可以指定數據集迭代的次數。如果沒有設置參數,則迭代會一直循環。通常來說,一直循環並直接用標準循環控制epoch的次數能取得較好的效果。

Shuffle

我們可以使用shuffle()方法將Dataset隨機洗牌,默認是在數據集中對每一個epoch洗牌,這種處理可以避免過擬合。

我們也可以設置buffer_size參數,下一個元素將從這個固定大小的緩存中按照均勻分布抽取。例子:

首次運行輸出:

第二次運行輸出:

這樣數據就被洗牌了。你還可以設置seed參數

▌Map

你可以使用map()方法對數據集的每個成員應用自定義的函數。在下面的例子中,我們將每個元素乘以2。

輸出:

其他資源

TensorFlow dataset tutorial: https://www.tensorflow.org/programmers_guide/datasets

Dataset docs:https://www.tensorflow.org/api_docs/python/tf/data/Dataset

▌結論

Dataset API提供了一種快速而且魯棒的方法來創建優化的輸入管道來訓練、評估和測試我們的模型。在這篇文章中,我們了解了很多常見的利用Dataset API的操作。

原文:https://towardsdatascience.com/how-to-use-dataset-in-tensorflow-c758ef9e4428

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 酷銳科技達人 的精彩文章:

地球上現存最古老的生物是什麼?

TAG:酷銳科技達人 |