實現屬於自己的TensorFlow(1):計算圖與前向傳播
(點擊
上方公眾號
,可快速關注)
來源: 伯樂在線 - iPytLab
http://blog.jobbole.com/113592/
前段時間因為課題需要使用了一段時間TensorFlow,感覺這種框架很有意思,除了可以搭建複雜的神經網路,也可以優化其他自己需要的計算模型,所以一直想自己學習一下寫一個類似的圖計算框架。前幾天組會開完決定著手實現一個模仿TensorFlow介面的簡陋版本圖計算框架以學習計算圖程序的編寫以及前向傳播和反向傳播的實現。目前實現了前向傳播和反向傳播以及梯度下降優化器,並寫了個優化線性模型的例子。
代碼放在了GitHub上,取名SimpleFlow, 倉庫鏈接: https://github.com/PytLab/simpleflow
雖然前向傳播反向傳播這些原理了解起來並不是很複雜,但是真正著手寫起來才發現,裡面還是有很多細節需要學習和處理才能對實際的模型進行優化(例如Loss函數對每個計算節點矩陣求導的處理)。其中SimpleFlow的代碼並沒有考慮太多的東西比如dtype
和張量size
的檢查等,因為只是為了實現主要圖計算功能並沒有考慮任何的優化, 內部張量運算使用的Numpy的介面(畢竟是學習和練手的目的嘛)。好久時間沒更新博客了,在接下來的幾篇裡面我將把實現的過程的細節總結一下,希望可以給後面學習的童鞋做個參考。
正文
本文主要介紹計算圖以及前向傳播的實現, 主要涉及
圖的構建
以及通過對構建好的圖進行後序遍歷
然後進行前向傳播計算得到具體節點上的輸出值。先貼上一個簡單的實現效果吧:
import
simpleflow
as
sf
# Create a graph
with
sf
.
Graph
().
as_default
()
:
a
=
sf
.
constant
(
1.0
,
name
=
"a"
)
b
=
sf
.
constant
(
2.0
,
name
=
"b"
)
result
=
sf
.
add
(
a
,
b
,
name
=
"result"
)
# Create a session to compute
with
tf
.
Session
()
as
sess
:
(
sess
.
run
(
result
))
計算圖(Computational Graph)
計算圖是計算代數中的一個基礎處理方法,我們可以通過一個有向圖來表示一個給定的數學表達式,並可以根據圖的特點快速方便對表達式中的變數進行求導。而神經網路的本質就是一個多層複合函數, 因此也可以通過一個圖來表示其表達式。
本部分主要總結計算圖的實現,在計算圖這個有向圖中,每個節點代表著一種特定的運算例如求和,乘積,向量乘積,平方等等… 例如求和表達式$
f(x,y)=x+y$
使用有向圖表示為:
表達式$
f(x,y,z)=z(x+y)$
使用有向圖表示為:與TensorFlow的實現不同,為了簡化,在SimpleFlow中我並沒有定義Tensor
類來表示計算圖中節點之間的數據流動,而是
直接定義節點的類型
,其中主要定義了四種類型來表示圖中的節點:
Operation
: 操作節點主要接受一個或者兩個輸入節點然後進行簡單的操作運算,例如上圖中的加法操作和乘法操作等。
Variable
: 沒有輸入節點的節點,此節點包含的數據在運算過程中是可以變化的。
Constant
: 類似
Variable
節點,也沒有輸入節點,此節點中的數據在圖的運算過程中不會發生變化
Placeholder
: 同樣沒有輸入節點,此節點的數據是通過圖建立好以後通過用戶傳入的
其實圖中的所有節點都可以看成是某種操作,其中
Variable
,
Constant
,
Placeholder
都是一種特殊的操作,只是相對於普通的
Operation
而言,他們沒有輸入,但是都會有輸出(像上圖中的
xx
,yy
節點,他們本身輸出自身的值到++
節點中去),通常會輸出到Operation
節點,進行進一步的計算。
下面我們主要介紹如何實現計算圖的基本組件: 節點和邊。
Operation
節點
節點表示操作,邊代表節點接收和輸出的數據,操作節點需要含有以下屬性:
input_nodes
: 輸入節點,裡面存放與當前節點相連接的輸入節點的引用
output_nodes
: 輸出節點, 存放以當前節點作為輸入的節點,也就是當前節點的去向
output_value
: 存儲當前節點的數值, 如果是
Add
節點,此變數就存儲兩個輸入節點
output_value
的和
name
: 當前節點的名稱
graph
: 此節點所屬的圖
下面我們定義了
Operation
基類用於表示圖中的操作節點(詳見
https://github.com/PytLab/simpleflow/blob/master/simpleflow/operations.py
):
class
Operation
(
object
)
:
""" Base class for all operations in simpleflow.
An operation is a node in computational graph receiving zero or more nodes
as input and produce zero or more nodes as output. Vertices could be an
operation, variable or placeholder.
"""
def
__init__
(
self
,
*
input_nodes
,
name
=
None
)
:
""" Operation constructor.
:param input_nodes: Input nodes for the operation node.
:type input_nodes: Objects of `Operation`, `Variable` or `Placeholder`.
:param name: The operation name.
:type name: str.
"""
# Nodes received by this operation.
self
.
input_nodes
=
input_nodes
# Nodes that receive this operation node as input.
self
.
output_nodes
=
[]
# Output value of this operation in session execution.
self
.
output_value
=
None
# Operation name.
self
.
name
=
name
# Graph the operation belongs to.
self
.
graph
=
DEFAULT_GRAPH
# Add this operation node to destination lists in its input nodes.
for
node
in
input_nodes
:
node
.
output_nodes
.
append
(
self
)
# Add this operation to default graph.
self
.
graph
.
operations
.
append
(
self
)
def
compute_output
(
self
)
:
""" Compute and return the output value of the operation.
"""
raise
NotImplementedError
def
compute_gradient
(
self
,
grad
=
None
)
:
""" Compute and return the gradient of the operation wrt inputs.
"""
raise
NotImplementedError
在初始化方法中除了定義上面提到的屬性外,還需要進行兩個操作:
將當前節點的引用添加到他輸入節點的
output_nodes
這樣可以在輸入節點中找到當前節點。
將當前節點的引用添加到圖中,方便後面對圖中的資源進行回收等操作
另外,每個操作節點還有兩個必須的方法:
comput_output
和
compute_gradient
. 他們分別負責根據輸入節點的值計算當前節點的輸出值和根據操作屬性和當前節點的值計算梯度。關於梯度的計算將在後續的文章中詳細介紹,本文只對節點輸出值的計算進行介紹。
下面我以
求和
操作為例來說明具體操作節點的實現:
class
Add
(
Operation
)
:
""" An addition operation.
"""
def
__init__
(
self
,
x
,
y
,
name
=
None
)
:
""" Addition constructor.
:param x: The first input node.
:type x: Object of `Operation`, `Variable` or `Placeholder`.
:param y: The second input node.
:type y: Object of `Operation`, `Variable` or `Placeholder`.
:param name: The operation name.
:type name: str.
"""
super
(
self
.
__class__
,
self
).
__init__
(
x
,
y
,
name
=
name
)
def
compute_output
(
self
)
:
""" Compute and return the value of addition operation.
"""
x
,
y
=
self
.
input_nodes
self
.
output_value
=
np
.
add
(
x
.
output_value
,
y
.
output_value
)
return
self
.
output_value
可見,計算當前節點
output_value
的值的
前提條件
就是他的輸入節點的值在此之前已經計算得到了
。Variable
節點
與
Operation
節點類似,
Variable
節點也需要
output_value
,
output_nodes
等屬性,但是它沒有輸入節點,也就沒有
input_nodes
屬性了,而是需要在創建的時候確定一個初始值
initial_value
:
class
Variable
(
object
)
:
""" Variable node in computational graph.
"""
def
__init__
(
self
,
initial_value
=
None
,
name
=
None
,
trainable
=
True
)
:
""" Variable constructor.
:param initial_value: The initial value of the variable.
:type initial_value: number or a ndarray.
:param name: Name of the variable.
:type name: str.
"""
# Variable initial value.
self
.
initial_value
=
initial_value
# Output value of this operation in session execution.
self
.
output_value
=
None
# Nodes that receive this variable node as input.
self
.
output_nodes
=
[]
# Variable name.
self
.
name
=
name
# Graph the variable belongs to.
self
.
graph
=
DEFAULT_GRAPH
# Add to the currently active default graph.
self
.
graph
.
variables
.
append
(
self
)
if
trainable
:
self
.
graph
.
trainable_variables
.
append
(
self
)
def
compute_output
(
self
)
:
""" Compute and return the variable value.
"""
if
self
.
output_value
is
None
:
self
.
output_value
=
self
.
initial_value
return
self
.
output_value
Constant
節點和
Placeholder
節點
Constant
和
Placeholder
節點與
Variable
節點類似,具體實現詳見:
https://github.com/PytLab/simpleflow/blob/master/simpleflow/operations.py
計算圖對象
在定義了圖中的節點後我們需要將定義好的節點放入到一個圖中統一保管,因此就需要定義一個
Graph
類來存放創建的節點,方便統一操作圖中節點的資源。
class
Graph
(
object
)
:
""" Graph containing all computing nodes.
"""
def
__init__
(
self
)
:
""" Graph constructor.
"""
self
.
operations
,
self
.
constants
,
self
.
placeholders
=
[],
[],
[]
self
.
variables
,
self
.
trainable_variables
=
[],
[]
為了提供一個默認的圖,在導入simpleflow模塊的時候創建一個全局變數來引用默認的圖:
from
.
graph
import
Graph
# Create a default graph.
import
builtins
DEFAULT_GRAPH
=
builtins
.
DEFAULT_GRAPH
=
Graph
()
為了模仿TensorFlow的介面,我們給Graph
添加上下文管理器協議方法使其成為一個上下文管理器, 同時也添加一個as_default
方法:
class
Graph
(
object
)
:
#...
def
__enter__
(
self
)
:
""" Reset default graph.
"""
global
DEFAULT_GRAPH
self
.
old_graph
=
DEFAULT_GRAPH
DEFAULT_GRAPH
=
self
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
exc_tb
)
:
""" Recover default graph.
"""
global
DEFAULT_GRAPH
DEFAULT_GRAPH
=
self
.
old_graph
def
as_default
(
self
)
:
""" Set this graph as global default graph.
"""
return
self
這樣在進入with
代碼塊之前先保存舊的默認圖對象然後將當前圖賦值給全局圖對象,這樣with
代碼塊中的節點默認會添加到當前的圖中。最後退出with
代碼塊時再對圖進行恢復即可。這樣我們可以按照TensorFlow的方式來在某個圖中創建節點.
Ok,根據上面的實現我們已經可以創建一個計算圖了:
import
simpleflow
as
sf
with
sf
.
Graph
().
as_default
()
:
a
=
sf
.
constant
([
1.0
,
2.0
],
name
=
"a"
)
b
=
sf
.
constant
(
2.0
,
name
=
"b"
)
c
=
a
*
b
前向傳播(Feedforward)
實現了計算圖和圖中的節點,我們需要對計算圖進行計算, 本部分對計算圖的前向傳播的實現進行總結。
會話
首先,我們需要實現一個
Session
來對一個已經創建好的計算圖進行計算,因為當我們創建我們之前定義的節點的時候其實只是創建了一個空節點,節點中並沒有數值可以用來計算,也就是
output_value
是空的。為了模仿TensorFlow的介面,我們在這裡也把session定義成一個上下文管理器:
class
Session
(
object
)
:
""" A session to compute a particular graph.
"""
def
__init__
(
self
)
:
""" Session constructor.
"""
# Graph the session computes for.
self
.
graph
=
DEFAULT_GRAPH
def
__enter__
(
self
)
:
""" Context management protocal method called before `with-block`.
"""
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
exc_tb
)
:
""" Context management protocal method called after `with-block`.
"""
self
.
close
()
def
close
(
self
)
:
""" Free all output values in nodes.
"""
all_nodes
=
(
self
.
graph
.
constants
+
self
.
graph
.
variables
+
self
.
graph
.
placeholders
+
self
.
graph
.
operations
+
self
.
graph
.
trainable_variables
)
for
node
in
all_nodes
:
node
.
output_value
=
None
def
run
(
self
,
operation
,
feed_dict
=
None
)
:
""" Compute the output of an operation."""
# ...
計算某個節點的輸出值
上面我們已經可以構建出一個計算圖了,計算圖中的每個節點與其相鄰的節點有方向的聯繫起來,現在我們需要根據圖中節點的關係來推算出某個節點的值。那麼如何計算呢? 還是以我們剛才¥
f(x,y,z)=z(x+y)$
的計算圖為例,若我們需要計算橙色
××
運算節點的輸出值,我們需要計算與它相連的兩個輸入節點的輸出值,進而需要計算綠色++
的輸入節點的輸出值。我們可以通過後序遍歷來獲取計算一個節點所需的所有節點的輸出值。為了方便實現,後序遍歷我直接使用了遞歸的方式來實現:
def
_get_prerequisite
(
operation
)
:
""" Perform a post-order traversal to get a list of nodes to be computed in order.
"""
postorder_nodes
=
[]
# Collection nodes recursively.
def
postorder_traverse
(
operation
)
:
if
isinstance
(
operation
,
Operation
)
:
for
input_node
in
operation
.
input_nodes
:
postorder_traverse
(
input_node
)
postorder_nodes
.
append
(
operation
)
postorder_traverse
(
operation
)
return
postorder_nodes
通過此函數我們可以獲取計算一個節點值所需要所有節點列表,再依次計算列表中節點的輸出值,最後便可以輕易的計算出當前節點的輸出值了。
class
Session
(
object
)
:
# ...
def
run
(
self
,
operation
,
feed_dict
=
None
)
:
""" Compute the output of an operation.
:param operation: A specific operation to be computed.
:type operation: object of `Operation`, `Variable` or `Placeholder`.
:param feed_dict: A mapping between placeholder and its actual value for the session.
:type feed_dict: dict.
"""
# Get all prerequisite nodes using postorder traversal.
postorder_nodes
=
_get_prerequisite
(
operation
)
for
node
in
postorder_nodes
:
if
type
(
node
)
is
Placeholder
:
node
.
output_value
=
feed_dict
[
node
]
else
:
# Operation and variable
node
.
compute_output
()
return
operation
.
output_value
例子
上面我們實現了計算圖以及前向傳播,我們就可以創建計算圖計算表達式的值了, 如下:
import
simpleflow
as
sf
# Create a graph
with
sf
.
Graph
().
as_default
()
:
w
=
sf
.
constant
([[
1
,
2
,
3
],
[
3
,
4
,
5
]],
name
=
"w"
)
x
=
sf
.
constant
([[
9
,
8
],
[
7
,
6
],
[
10
,
11
]],
name
=
"x"
)
b
=
sf
.
constant
(
1.0
,
"b"
)
result
=
sf
.
matmul
(
w
,
x
)
+
b
# Create a session to compute
with
sf
.
Session
()
as
sess
:
(
sess
.
run
(
result
))
輸出值:
array
([[
54.
,
54.
],
[
106.
,
104.
]])
總結
本文使用Python實現了計算圖以及計算圖的前向傳播,並模仿TensorFlow的介面創建了
Session
以及
Graph
對象。下篇中將繼續總結計算圖節點計算梯度的方法以及反向傳播和梯度下降優化器的實現。
最後再附上simpleflow項目的鏈接, 歡迎相互學習和交流:
https://github.com/PytLab/simpleflow
參考
Deep Learning From Scratch
https://en.wikipedia.org/wiki/Tree_traversal#Post-order
https://zhuanlan.zhihu.com/p/25496760
http://blog.csdn.net/magic_anthony/article/details/77531552#0-tsina-1-98885-397232819ff9a47a7b7e80a40613cfe1
看完本文有收穫?請轉
發分享給更多人
關注「P
ython開發者」,提升Python技能
※90 道名企筆試和演算法題 (含答題討論)
※AI 玩跳一跳的正確姿勢,跳一跳 Auto-Jump 演算法詳解
TAG:Python開發者 |