如果說 Tensorflow 1.x 的問題在於缺乏一個友善以及容易除錯的介面,那麼 PyTorch 的問題則是所建立的模型無法在高效能的語言,如 C++ 中使用。於是,前者重新大幅改寫了 1.x 的 API 並準備進入了 2.0 ,後者,則是在 v1.2 的時候,介紹了 TorchScript。
TorchScript 是一個 JIT 編譯器,目的在於將 PyTorch 寫成的 python 原始碼透過 TorchScript 編譯成一個 靜態計算圖,並且能夠配置到高效能的計算環境。在編譯的過程中,可針對計算叢集的特性,而對不同 device 的使用優化,並使模型能在非 python 的高效能執行環境上執行。
nn.Module
到 ScriptModules
讓我們從 python ...
nn.Module
模型如下:class MyCell(torch.nn.Module): # 繼承 torch.nn.Module
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4) # 定義一個 submodule
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h) # 線性轉換,再使用 tanh 作為啟動函式
return new_h
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell)
# => MyCell(
# (linear): Linear(in_features=4, out_features=4, bias=True)
#)
print(my_cell(x, h))
# => (tensor([[0.5218, 0.5271, 0.3068, 0.6263],
# [0.7692, 0.4384, 0.7650, 0.9174],
# [0.4221, 0.7420, 0.0208, 0.8135]], grad_fn=<TanhBackward>), tensor([[0.5218, #0.5271, 0.3068, 0.6263],
# [0.7692, 0.4384, 0.7650, 0.9174],
# [0.4221, 0.7420, 0.0208, 0.8135]], grad_fn=<TanhBackward>))
一般在建立一個 PyTorch nn.Module
時需要複寫兩個方法:
__init__
: 在這裡使用者可以建立 nn.Module
模組內的類別物件,如,nn.Module.Linear
等。任何在 __init__
中宣告的 nn.Module
物件會被加到 submodules 內,在列印模型時會看到這些 submodules。forward
:這個方法則是你定義正向傳播的邏輯,需要傳回該 nn.Module
最後的輸出值。traced_cell = torch.jit.trace(my_cell, (x, h)) # 建立一個 torch.jit.ScriptModule 物件
print(traced_cell)
#=>TracedModule[MyCell](
# (linear): TracedModule[Linear]()
#)
print(traced_cell(x, h))
#=> (tensor([[0.5218, 0.5271, 0.3068, 0.6263],
# [0.7692, 0.4384, 0.7650, 0.9174],
# [0.4221, 0.7420, 0.0208, 0.8135]],
# grad_fn=<DifferentiableGraphBackward>), tensor([[0.5218, 0.5271, 0.3068, #0.6263],
# [0.7692, 0.4384, 0.7650, 0.9174],
# [0.4221, 0.7420, 0.0208, 0.8135]],
# grad_fn=<DifferentiableGraphBackward>))
在這裡,呼叫 torch.jit.trace
方法,它就像是一個錄影帶一般,把 forward 運行的流程記錄下來,編譯成一個執行檔或是一個 torch.jit.ScriptModule
或 torch.jit.ScriptFunction
物件。當我們列印出 traced_cell(x, h)
的輸出時,則發現經過 jit.trace
編譯後的 Tensors,除了他們的 grad_fun 名稱改變外,計算結果是一樣的。
我們還可以透過 torch.jit.ScriptModule
物件的 graph 方法,來 inspect 回溯後所建立的靜態計算圖。而結果如下:
print(traced_cell.graph)
#=> graph(%input : Float(3, 4),
# %h : Float(3, 4),
# %14 : Tensor,
# %15 : Tensor):
# %6 : Float(4!, 4!) = aten::t(%15), scope: MyCell/Linear[linear]
# %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear]
# %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear]
# %9 : Float(3, 4) = aten::addmm(%14, %input, %6, %7, %8), scope: MyCell/Linear[linear]
# %10 : int = prim::Constant[value=1](), scope: MyCell
# %11 : Float(3, 4) = aten::add(%9, %h, %10), scope: MyCell
# %12 : Float(3, 4) = aten::tanh(%11), scope: MyCell
# %13 : (Float(3, 4), Float(3, 4)) = prim::TupleConstruct(%12, %12)
# return (%13)
torch.jit.ScriptModule.graph
的輸出似乎很難閱讀,所以我們們可以透過提取 code
這個物件屬性,從 python syntax 的角度來看這張圖儲存了什麼計算流程。
print(traced_cell.code)
#=> def forward(input,
# h: Tensor,
# slot0: Tensor,
# slot1: Tensor) -> Tuple[Tensor, Tensor]:
# _0 = torch.addmm(slot0, input, torch.t(slot1), beta=1, alpha=1)
# _1 = torch.tanh(torch.add(_0, h, alpha=1))
# return (_1, _1)
可以看到有關參數的部分都被抽象化,有些運算元則被最簡化,不過大致上還是可以看到他的流程,與 MyCell forward 方法中的計算邏輯是相等的。
如果我們想要包括 if-else 或 for-loop 這樣子的動態的流程控制邏輯(control-flow)呢?讓我們再更新我們的模型,讓他擁有一個動態決定的流程控制邏輯:
class MyDecisionGate(torch.nn.Module):
def __init__(self):
super(MyDecisionGate, self).__init__()
def forward(self, x):
# adding flow control
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.dg = MyDecisionGate()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
# apply MyDecisionGate below
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
my_cell = MyCell()
print(my_cell)
#=> MyCell(
# (dg): MyDecisionGate()
# (linear): Linear(in_features=4, out_features=4, bias=True)
#)
print(my_cell(x, h))
#=> (tensor([[0.4429, 0.7279, 0.8599, 0.7389],
# [0.9568, 0.2840, 0.8771, 0.8186],
# [0.5781, 0.7588, 0.7208, 0.6238]], grad_fn=<TanhBackward>), tensor([[0.4429, #0.7279, 0.8599, 0.7389],
# [0.9568, 0.2840, 0.8771, 0.8186],
# [0.5781, 0.7588, 0.7208, 0.6238]], grad_fn=<TanhBackward>))
若用 tracing jit compiler,我們從 code 的輸出發現,原本存在的 if-else 不存在了,而且還附贈了一個警告,告訴你用 tracing compiler 編譯程式碼是不智的,因為 tracing compiler 會將在 if-statement 的 tensor 直接當作布林常數,而這樣的處理方法會導致執行時容易出錯。
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
#=> def forward(input,
# h: Tensor,
# slot0: Tensor,
# slot1: Tensor) -> Tuple[Tensor, Tensor]:
# x = torch.addmm(slot0, input, torch.t(slot1), beta=1, alpha=1)
# _0 = torch.tanh(torch.add(x, h, alpha=1))
# return (_0, _0)
#/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: TracerWarning: #Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
# This is separate from the ipykernel package so we can avoid doing imports until
或許我們把 MyDecisionGate 經過 tracing 編譯的結果列印出來,就更能明白警告訊息的意思。下面就是 tracing 編譯 MyCell 的 dg 物件屬性,一個 MyDecisionGate 的實例。可以看到 traced_gate.code
所對應的程式碼並沒有清楚的流程控制邏輯,反而是 Tensor 對 Tensor 的對映。
traced_gate = torch.jit.trace(my_cell.dg, (x,))
print(traced_gate.code)
#=>def forward(self,
# x: Tensor) -> Tensor:
# return x
這個時候,我們就可以用 script compiler 去保留 MyDecisionGate 的流程邏輯
scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate)
#=> print(scripted_gate)=> WeakScriptModuleProxy()
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell)
#=> WeakScriptModuleProxy(
# (dg): WeakScriptModuleProxy()
# (linear): WeakScriptModuleProxy()
#)
可以看到經過 torch.jit.script
compiler 編譯出的結果是一個 WeakScriptModuleProxy
[註一] 物件。而原始碼的部分,則可以見到保留了流程控制的部分。
print(scripted_gate.code)
#=> def forward(self,
# x: Tensor) -> Tensor:
# _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
# if _0:
# _1 = x
# else:
# _1 = torch.neg(x)
# return _1
在有些時候,我們會希望在模型中將 tracing 和 scripting 交換著使用。如果模型中較少未知的變數,或對輸入的資料為高依賴性,我們就可以先用 tracing 編譯器將處理較少,依賴輸入資料的部分,隨後在用 script 模式去編譯出動態流程控制的編譯碼。
同樣地,倘若我們需要在較低的模組有較為精細的控制,我們可以先用 script 模式來編譯 submodules,最後再用 tracing 編譯器,來做 high-level 的加速。
ScriptModules
根據 PyTorch 的官方 TorchScript 的介紹,將 pure python 寫成的 nn.Module
類別物件由 TorchScript JIT 編譯器 byte code 後,可以:
ScriptModules
,可以 export 到後端 server 能讀取的格式,使 serving 能更容易。[註一] 在 PyTorch 1.3 為 torch.jit.ScriptModule
物件,而非如 1.2 是 torch.jit.ScriptModule
的 subclass torch.jit.WeakScriptModuleProxy