當前位置:
首頁 > 新聞 > 一行代碼切換TensorFlow與PyTorch,模型訓練也能用倆框架

一行代碼切換TensorFlow與PyTorch,模型訓練也能用倆框架

機器之心報道

參與:思源


你是否有時要用 PyTorch,有時又要跑 TensorFlow?這個項目就是你需要的,你可以在訓練中同時使用兩個框架,並端到端地轉換模型。也就是說 TensorFlow 寫的計算圖可以作為某個函數,直接應用到 Torch 的張量上,這操作也是很厲害了。

在早兩天開源的 TfPyTh 中,不論是 TensorFlow 還是 PyTorch 計算圖,它們都可以包裝成一個可微函數,並在另一個框架中高效完成前向與反向傳播。

很顯然,這樣的框架交互,能節省很多重寫代碼的麻煩事。

github項目地址:BlackHC/TfPyTh

為什麼框架間的交互很重要

目前 GitHub 上有很多優質的開源模型,它們大部分都是用 PyTorch 和 TensorFlow 寫的。如果我們想要在自己的項目中調用某個開源模型,那麼它們最好都使用相同的框架,不同框架間的對接會帶來各種問題。當然要是不怕麻煩,也可以用不同的框架重寫一遍。

一行代碼切換TensorFlow與PyTorch,模型訓練也能用倆框架

以前 TensorFlow 和 PyTorch 經常會用來對比,討論哪個才是更好的深度學習框架。但是它們之間就不能友好相處么,模型在兩者之間的相互遷移應該能帶來更多的便利。

在此之前,Facebook 和微軟就嘗試過另一種方式,即神經網路交換格式 ONNX。直觀而言,該工具定義了一種通用的計算圖,不同深度學習框架構建的計算圖都能轉化為它。雖然目前 ONNX 已經原生支持 MXNet、PyTorch 和 Caffe2 等大多數框架,但是像 TensorFlow 或 Keras 之類的只能通過第三方轉換器轉換為 ONNX 格式。

而且比較重要的一點是,現階段 ONNX 只支持推理,導入的模型都需要在原框架完成訓練。所以,想要加入其它框架的模型,還是得手動轉寫成相同框架,再執行訓練。

神奇的轉換庫 TfPyTh

既然 ONNX 無法解決訓練問題,那麼就輪到 TfPyTh 這類項目出場了,它無需改寫已有的代碼就能在框架間自由轉換。具體而言,TfPyTh 允許我們將 TensorFlow 計算圖包裝成一個可調用、可微分的簡單函數,然後 PyTorch 就能直接調用它完成計算。反過來也是同樣的,TensorFlow 也能直接調用轉換後的 PyTorch 計算圖。因為轉換後的模塊是可微的,那麼正向和反向傳播都沒什麼問題。不過項目作者也表示該項目還不太完美,開源 3 天以來會有一些小的問題。例如張量必須通過 CPU 進行複製與路由,直到 TensorFlow 支持__cuda_array_interface 相關功能才能解決。

目前 TfPyTh 主要支持三大方法:

  • torch_from_tensorflow:創建一個 PyTorch 可微函數,並給定 TensorFlow 佔位符輸入計算張量輸出;
  • eager_tensorflow_from_torch:從 PyTorch 創建一個 Eager TensorFlow 函數;
  • tensorflow_from_torch:從 PyTorch 創建一個 TensorFlow 運運算元或張量。

TfPyTh 示例

如下所示為 torch_from_tensorflow 的使用案例,我們會用 TensorFlow 創建一個簡單的靜態計算圖,然後傳入 PyTorch 張量進行計算。

import tensorflow as tf
import torch as th
import numpy as np
import tfpyth
session = tf.Session()
def get_torch_function():
a = tf.placeholder(tf.float32, name="a")
b = tf.placeholder(tf.float32, name="b")
c = 3 * a + 4 * b * b
f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply
return f
f = get_torch_function()
a = th.tensor(1, dtype=th.float32, requires_grad=True)
b = th.tensor(3, dtype=th.float32, requires_grad=True)
x = f(a, b)
assert x == 39.
x.backward()
assert np.allclose((a.grad, b.grad), (3., 24.))

我們可以發現,基本上 TensorFlow 完成的就是一般的運算,例如設置佔位符和建立計算流程等。TF 的靜態計算圖可以通過 session 傳遞到 TfPyTh 庫中,然後就產生了一個新的可微函數。後面我們可以將該函數用於模型的某個計算部分,再進行訓練也就沒什麼問題了。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

數千人頂會的乾貨,ICML、CVPR2019演講視頻資源在此

TAG:機器之心 |