首個端到端神經網路,看AI在玩遊戲時注意什麼
新智元專欄
作者:Flood Sung
編輯:費欣欣
【新智元導讀】不用傳統外掛,訓練純深度學習AI來玩跳一跳,結果會如何?本文作者使用模仿學習,訓練了一個端到端的神經網路玩跳一跳,使用注意力機制(Attention)分析後發現,神經網路在玩遊戲時,也會自動捕捉棋子與盒子的重要位置信息。代碼不過100行,希望這個工作能讓大家體會到深度學習的魅力!
微信跳一跳自發布以來,迅速成為人們茶餘飯後的休閑利器,同時也演變成了各路程序員的競技場。程序員們爭先開發出各種牛逼外掛,把小遊戲玩出了新的境界。
然而,目前出來的各種外掛版本,大多採用傳統的方法來實現,比如使用傳統計算機視覺的顏色,邊緣檢測等方法來尋找棋子的位置。雖然已能夠在遊戲中取得較好分數,但是代碼複雜,需要針對不同手機設置不同的參數。
在深度學習如此火熱,AlphaZero已經征服各種棋類,Atari遊戲已經被計算機吊打的情況下,能不能用純深度神經網路來玩跳一跳呢?
答案當然是Yes。
本文中,我們就祭出微信跳一跳的終極奧義:僅使用一個端到端的神經網路,遠遠超越人類水平!
AI玩跳一跳的關鍵:從二維圖像學會三維距離的概念
在介紹端到端神經網路的實現方法前,我們首先要考慮,訓練AI來玩跳一跳,關鍵在哪裡?
UCL計算機系教授汪軍在接受新智元採訪時說,關鍵是讓AI通過觀看二維圖像,學會「三維距離」的概念。
一個用線性模型玩跳一跳的結果
使用模仿學習,把複雜程序全都裝進一個神經網路
那麼,端到端神經網路玩跳一跳是怎麼實現的呢?
估計有很多朋友的第一反應是:難道是用深度增強學習(Deep Reinforcement Learning),也就是AlphaGo的方法?
老實說,我一開始也是打算這麼乾的,但是具體分析後發現並不是那麼好做。
首先,我們獲取不了遊戲內部數據,不方便獲取Reward,即使我們要自己設定一個Reward,比如計算棋子離盒子正中心的位置,我們依然需要通過計算機視覺分析的方式實現,違背了我們純端到端神經網路實現的初衷。
其次,跳一跳這個小遊戲,玩的速度很慢,基本上一步要一秒多。而使用深度增強學習最大的問題就是需要訓練幾十萬步。這樣看來,要用深度增強學習訓練出一個能玩的網路,大概要訓練到明年吧!
因此,我們不用深度增強學習,而改用模仿學習(Imitation Learning),並且使用模仿學習中最簡單的行為克隆(Behavior Cloning)方法。簡單的說,就是收集很多好的遊戲輸入輸出數據,然後使用監督學習訓練。
本質上說,這種模仿學習的做法就是把人工編寫的複雜程序存進一個神經網路中。
好了,確定了方法論。下面就可以開始具體實施了。
具體實施:Talk is Easy,Show Me The Code
1、構造遊戲資料庫
從哪裡搞到很多的遊戲輸入輸出數據呢?別忘了,傳統方法都已經可以玩到10000分以上了,我們完全可以用傳統方法來收集數據。
在這裡,我使用加了點小AI的代碼,通過線性回歸訓練了一個簡單的線性模型來估計跳躍距離和按壓時間的關係,相對人工設定的參數會好一些。有了這個,我們就可以將每一次跳躍的屏幕截圖及按壓時間記錄下來:
其中圖像先做一下預處理,居中裁成正方形,方便之後的訓練,而按壓數據則存在Json中,一個圖片名稱對應一個按壓時間。圖片名稱很簡單,直接使用的截圖時間。
就這樣,我們讓微信跳一跳跳了n個小時,終於收集到了5000多個數據樣本。
有了資料庫,下面就是如何訓練了。
2、構造端到端神經網路模型及訓練
這裡我構造了一個5層的卷積神經網路,每一層神經網路包含一個64 通道的3x3 卷積核的卷積層,一個BatchNorm,一個ReLU及一個2x2的Max-Pooling層。具體如下圖所示:
由於輸出的按壓時間是一個單值,非常簡單,我們使用Mean Square Error來作為模型的損失函數Loss。因此,我們的模型是一個簡單的回歸模型。我們使用構建的資料庫進行訓練。在訓練之前,我們對圖像數據進行預處理,將其壓縮成224x224的RGB圖像,然後再輸入到神經網路。我們採用Adam作為優化器,學習率設定為0.001,訓練200個episode,一個episode隨機遍歷整個數據集一遍。
3、代碼
具體的代碼:https://github.com/songrotek/wechat_jump_end_to_end
這個Github僅包含所需的運行代碼,就兩個文件一個ios,一個android。使用方法非常簡單:
(1) 安裝iOS或Android開發所需的軟體及依賴,具體詳見:
https://github.com/wangshub/wechat_jump_game/wiki/Android-%E5%92%8C-iOS-%E6%93%8D%E4%BD%9C%E6%AD%A5%E9%AA%A4
(2) 安裝本代碼所需的PyTorch深度學習框架:pytorch.org
(3) 手機連接好電腦,注意iPhone需要在run_ios.py中更改WebDriverAgentRunner 運行後得到的IP。打開微信跳一跳,然後在Terminal中輸入:python run_ios.py或者python run_android.py
接下來就是見證奇蹟的時刻!
再看看代碼,也就是100行!驚不驚喜!意不意外!
神經網路在玩跳一跳的過程中思考了嗎?
上面就是端到端神經網路的實現方法,看起來過於簡單了。只玩到這不太符合我們的Geek精神。因此,我們不禁要問:神經網路在玩遊戲的過程中「思考」了嗎?有沒有像人類一樣,考慮了確定棋子和盒子的位置等問題?
為了驗證這一點,我們做了額外的實驗,構建一個帶有注意力(Attention)機制的神經網路進行訓練。我們使用一個4層的U-Net來輸出一個和圖像輸入維度一致的注意力蒙版(Attention Mask),然後將原有圖像與注意力蒙版相乘(Element-wise Product),得到帶蒙版的圖像,即僅考慮注意力區域的圖像。之後,再將帶蒙版圖像輸入到4層卷積全連接後輸出按壓時間。具體網路結構如下圖所示:
注意力蒙版每一個維度的值我們限制為[0,1],越趨於1就表示越關注,反之亦然。基於這樣的網路模型訓練後,我們就可以來看看神經網路在關注些什麼。下面是一些對應的截圖:
端到端的神經網路在玩跳一跳過程中,自動捕捉位置等關鍵信息
可以看出,神經網路一定程度上自動捕捉了棋子和盒子的位置信息,特別注意棋子上頭的高亮,這非常符合人玩遊戲的方式,也符合傳統做法的方法。
這在一定程度上說明,整個端到端神經網路內部也會自動捕捉到這些重要的位置信息!
小結
微信跳一跳的終極奧義就介紹到這了!大家肯定會驚訝於深度學習的神奇之處。老實說只看運行的代碼我也非常驚訝。但是,再看看訓練的方式似乎不過如此。這大概也是深度學習的魅力吧!
希望這個Work能給大家帶來更多歡樂!
加入社群
新智元AI技術+產業社群招募中,歡迎對AI技術+產業落地感興趣的同學,加小助手微信號:aiera2015_1入群;通過審核後我們將邀請進群,加入社群後務必修改群備註(姓名-公司-職位;專業群審核較嚴,敬請諒解)。
此外,新智元AI技術+產業領域社群(智能汽車、機器學習、深度學習、神經網路等)正在面向正在從事相關領域的工程師及研究人員進行招募。
加入新智元技術社群 共享AI+開放平台
TAG:新智元 |