新手必讀系列:實例解讀TensorFlow
TensorFlow是Google開發的開源軟體庫,用於機器學習。它能夠在所有Linux、Windows和MacOS平台上的CPU和GPU上運行。 Tensorflow可用於設計、實施和訓練受大腦結構和功能啟發的深度學習模型。
在本文中,將逐步指導使用TensorFlow實現簡單的神經網路。著名的鳶尾花數據集在此用於訓練,然後將給定的花分類到正確的類型。
該數據集包含3種「Setosa」、「Versicolor」和「Virginica」。為了識別每種花型,有4種萼片長度、萼片寬度、花瓣長度和花瓣寬度的屬性。我們將實施一個神經網路來使用這些因素識別正確的類型。
首先,我們將使用訓練數據集來訓練我們的模型,然後將使用測試數據集來測試它的準確性。你可以從這裡下載訓練數據集,並從這裡測試數據集。
步驟1
首先,我們需要讀取.csv文件中的數據並導入它們。Pandas 可以用來輕鬆處理這個問題。
導入數據
Pandas中的read_csv()函數將讀取文件並將內容載入到指定的變數。作為函數的參數,我們需要指定文件的路徑,而名稱參數可以用來指定每個文件的列名。
步驟2
數據集中每種花都被編碼為0、1和2。我們需要使用一種熱門方法將它們編碼為[1,0,0],[0,1,0]和[0,0,1]。這將使網路的訓練和優化變得容易,因為網路的輸出也是以一種熱門格式生成的。
那麼我們需要定義訓練集的x、y和測試集的x、y
編碼類和數據分離
步驟3
它需要為輸入(X)、輸出(Y)定義佔位符,並定義網路的權重和偏差。這裡我們有4列輸入,因為數據集有4個特徵和3列輸出來映射3種類型的花。佔位符的形狀應該滿足這一點。此外,權重矩陣的形狀必須是4x3,並且偏差必須是3的矢量以將輸入映射到輸出(無隱藏層)。
定義權重,偏差和佔位符
步驟4
然後我們需要通過一個激活函數發送輸出,這裡使用了tensorflow中可用的soft-max函數。為了訓練模型,我們需要計算模型創建的輸出中有多少誤差的成本。我們在這裡計算均方誤差。然後我們可以訓練模型,使用AdamOptimizer降低成本。
訓練張量定義
步驟5
經過訓練以檢查我們的模型是否準確,我們需要將模型預測與實際結果進行比較。然後我們可以通過計算得到的正確結果來判斷模型的準確性。
計算準確度
在這裡,需要檢查由我們的模型生成的輸出是否等於實際結果(Y)。該模型將計算每種花的價值,這可以視為每種類型的概率。我們選擇最有可能的類型。 argmax函數將返回最大值的索引。請記住,結果是一種one-hot形式,這種方法很容易讓我們檢查正確性。
在此之後,我們必須開始訓練模型。在此之前,我們需要首先初始化所有全局變數,然後使用global_variables_initializer函數。
步驟6
現在我們來訓練我們的模型。張量執行必須在張量流中的一個會話內完成。因此,在培訓之前我們需要創建一個會話,並且在完成所有事情之後,我們需要關閉會話。
訓練模式
上面的代碼塊被添加來創建tensorflow會話,並且該塊中的所有內容都將具有該會話。此外,這個代碼塊能夠在一切完成時自動關閉會話。
首先,執行變數初始張量,然後模型訓練1000次。訓練時,我們需要將訓練數據集指定為X,將相應的結果指定為Y,因為訓練tensorflow期望它們執行。在傳遞Y時,已經迭代並創建了一個新數組,以確保它具有與上面定義的形狀相同的形狀。在每次迭代中,跟蹤成本以繪製圖形以查看實際的訓練。
最後一步
最後,當訓練過程結束時,繪製成本變化圖並通過測試數據集測試模型的準確性。
繪製圖形和檢查準確性
經過1000次訓練迭代後,可以獲得96.67%的準確度,這確實令人印象深刻,成本變化圖通過降低每次迭代的成本顯示了該模型的重大發展。
因訓練而降低成本
這篇文章將有助於tensorflow的新手通過這個簡單的例子來理解它的概念。
TAG:AI全球動態 |