iT邦幫忙

第 11 屆 iT 邦幫忙鐵人賽

DAY 8
0
AI & Data

深度學習裡的冰與火之歌 : Tensorflow vs PyTorch系列 第 8

Day 8: Tensorflow 在 2.0 很好,那麼 PyTorch 呢?

如果說 Tensorflow 1.x 的問題在於缺乏一個友善以及容易除錯的介面,那麼 PyTorch 的問題則是所建立的模型無法在高效能的語言,如 C++ 中使用。於是,前者重新大幅改寫了 1.x 的 API 並準備進入了 2.0 ,後者,則是在 v1.2 的時候,介紹了 TorchScript。

什麼是 TorchScript,它和 PyTorch 有什麼關係?

TorchScript 是一個 JIT 編譯器,目的在於將 PyTorch 寫成的 python 原始碼透過 TorchScript 編譯成一個 靜態計算圖,並且能夠配置到高效能的計算環境。在編譯的過程中,可針對計算叢集的特性,而對不同 device 的使用優化,並使模型能在非 python 的高效能執行環境上執行。

nn.ModuleScriptModules

tracing jit compiler

讓我們從 python ...

定義一個 nn.Module 模型如下:

class MyCell(torch.nn.Module):   # 繼承 torch.nn.Module
    def __init__(self):
        super(MyCell, self).__init__() 
        self.linear = torch.nn.Linear(4, 4) # 定義一個 submodule

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h) # 線性轉換,再使用 tanh 作為啟動函式
        return new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell)
# => MyCell(
#  (linear): Linear(in_features=4, out_features=4, bias=True)
#)
print(my_cell(x, h))
# => (tensor([[0.5218, 0.5271, 0.3068, 0.6263],
#        [0.7692, 0.4384, 0.7650, 0.9174],
#        [0.4221, 0.7420, 0.0208, 0.8135]], grad_fn=<TanhBackward>), tensor([[0.5218, #0.5271, 0.3068, 0.6263],
#        [0.7692, 0.4384, 0.7650, 0.9174],
#        [0.4221, 0.7420, 0.0208, 0.8135]], grad_fn=<TanhBackward>))

一般在建立一個 PyTorch nn.Module 時需要複寫兩個方法:

  1. __init__: 在這裡使用者可以建立 nn.Module 模組內的類別物件,如,nn.Module.Linear等。任何在 __init__ 中宣告的 nn.Module 物件會被加到 submodules 內,在列印模型時會看到這些 submodules。
  2. forward:這個方法則是你定義正向傳播的邏輯,需要傳回該 nn.Module 最後的輸出值。

tracing

traced_cell = torch.jit.trace(my_cell, (x, h)) # 建立一個 torch.jit.ScriptModule 物件
print(traced_cell)
#=>TracedModule[MyCell](
#  (linear): TracedModule[Linear]()
#)
print(traced_cell(x, h))
#=> (tensor([[0.5218, 0.5271, 0.3068, 0.6263],
#        [0.7692, 0.4384, 0.7650, 0.9174],
#        [0.4221, 0.7420, 0.0208, 0.8135]],
#       grad_fn=<DifferentiableGraphBackward>), tensor([[0.5218, 0.5271, 0.3068, #0.6263],
#        [0.7692, 0.4384, 0.7650, 0.9174],
#        [0.4221, 0.7420, 0.0208, 0.8135]],
#       grad_fn=<DifferentiableGraphBackward>))

在這裡,呼叫 torch.jit.trace 方法,它就像是一個錄影帶一般,把 forward 運行的流程記錄下來,編譯成一個執行檔或是一個 torch.jit.ScriptModuletorch.jit.ScriptFunction 物件。當我們列印出 traced_cell(x, h) 的輸出時,則發現經過 jit.trace 編譯後的 Tensors,除了他們的 grad_fun 名稱改變外,計算結果是一樣的。

我們還可以透過 torch.jit.ScriptModule 物件的 graph 方法,來 inspect 回溯後所建立的靜態計算圖。而結果如下:

print(traced_cell.graph)
#=> graph(%input : Float(3, 4),
#      %h : Float(3, 4),
#      %14 : Tensor,
#      %15 : Tensor):
#  %6 : Float(4!, 4!) = aten::t(%15), scope: MyCell/Linear[linear]
#  %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear]
#  %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear]
#  %9 : Float(3, 4) = aten::addmm(%14, %input, %6, %7, %8), scope: MyCell/Linear[linear]
#  %10 : int = prim::Constant[value=1](), scope: MyCell
#  %11 : Float(3, 4) = aten::add(%9, %h, %10), scope: MyCell
#  %12 : Float(3, 4) = aten::tanh(%11), scope: MyCell
#  %13 : (Float(3, 4), Float(3, 4)) = prim::TupleConstruct(%12, %12)
#  return (%13)

torch.jit.ScriptModule.graph 的輸出似乎很難閱讀,所以我們們可以透過提取 code 這個物件屬性,從 python syntax 的角度來看這張圖儲存了什麼計算流程。

print(traced_cell.code)
#=> def forward(input,
#    h: Tensor,
#    slot0: Tensor,
#    slot1: Tensor) -> Tuple[Tensor, Tensor]:
#  _0 = torch.addmm(slot0, input, torch.t(slot1), beta=1, alpha=1)
#  _1 = torch.tanh(torch.add(_0, h, alpha=1))
#  return (_1, _1)

可以看到有關參數的部分都被抽象化,有些運算元則被最簡化,不過大致上還是可以看到他的流程,與 MyCell forward 方法中的計算邏輯是相等的。

script compiler

如果我們想要包括 if-else 或 for-loop 這樣子的動態的流程控制邏輯(control-flow)呢?讓我們再更新我們的模型,讓他擁有一個動態決定的流程控制邏輯:

class MyDecisionGate(torch.nn.Module):
  def __init__(self):
    super(MyDecisionGate, self).__init__()
    
  def forward(self, x):
    # adding flow control
    if x.sum() > 0:
      return x
    else:
      return -x


class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        # apply MyDecisionGate below
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
#=> MyCell(
#  (dg): MyDecisionGate()
#  (linear): Linear(in_features=4, out_features=4, bias=True)
#)
print(my_cell(x, h))
#=> (tensor([[0.4429, 0.7279, 0.8599, 0.7389],
#        [0.9568, 0.2840, 0.8771, 0.8186],
#        [0.5781, 0.7588, 0.7208, 0.6238]], grad_fn=<TanhBackward>), tensor([[0.4429, #0.7279, 0.8599, 0.7389],
#        [0.9568, 0.2840, 0.8771, 0.8186],
#        [0.5781, 0.7588, 0.7208, 0.6238]], grad_fn=<TanhBackward>))

若用 tracing jit compiler,我們從 code 的輸出發現,原本存在的 if-else 不存在了,而且還附贈了一個警告,告訴你用 tracing compiler 編譯程式碼是不智的,因為 tracing compiler 會將在 if-statement 的 tensor 直接當作布林常數,而這樣的處理方法會導致執行時容易出錯。

traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
#=> def forward(input,
#    h: Tensor,
#    slot0: Tensor,
#    slot1: Tensor) -> Tuple[Tensor, Tensor]:
#  x = torch.addmm(slot0, input, torch.t(slot1), beta=1, alpha=1)
#  _0 = torch.tanh(torch.add(x, h, alpha=1))
#  return (_0, _0)

#/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: TracerWarning: #Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
#  This is separate from the ipykernel package so we can avoid doing imports until

或許我們把 MyDecisionGate 經過 tracing 編譯的結果列印出來,就更能明白警告訊息的意思。下面就是 tracing 編譯 MyCell 的 dg 物件屬性,一個 MyDecisionGate 的實例。可以看到 traced_gate.code 所對應的程式碼並沒有清楚的流程控制邏輯,反而是 Tensor 對 Tensor 的對映。

traced_gate = torch.jit.trace(my_cell.dg, (x,))
print(traced_gate.code)
#=>def forward(self,
#    x: Tensor) -> Tensor:
#  return x

這個時候,我們就可以用 script compiler 去保留 MyDecisionGate 的流程邏輯

scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate)
#=> print(scripted_gate)=> WeakScriptModuleProxy()
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell)
#=> WeakScriptModuleProxy(
#  (dg): WeakScriptModuleProxy()
#  (linear): WeakScriptModuleProxy()
#)

可以看到經過 torch.jit.script compiler 編譯出的結果是一個 WeakScriptModuleProxy[註一] 物件。而原始碼的部分,則可以見到保留了流程控制的部分。

print(scripted_gate.code)
#=> def forward(self,
#    x: Tensor) -> Tensor:
#  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
#  if _0:
#    _1 = x
#  else:
#    _1 = torch.neg(x)
#  return _1

Mixed tracing and scripting

在有些時候,我們會希望在模型中將 tracing 和 scripting 交換著使用。如果模型中較少未知的變數,或對輸入的資料為高依賴性,我們就可以先用 tracing 編譯器將處理較少,依賴輸入資料的部分,隨後在用 script 模式去編譯出動態流程控制的編譯碼。

同樣地,倘若我們需要在較低的模組有較為精細的控制,我們可以先用 script 模式來編譯 submodules,最後再用 tracing 編譯器,來做 high-level 的加速。

為何要編譯成 ScriptModules

根據 PyTorch 的官方 TorchScript 的介紹,將 pure python 寫成的 nn.Module類別物件由 TorchScript JIT 編譯器 byte code 後,可以:

  1. 在不同的 python 直譯器上執行,而不需要在 CPython 上執行。這些 python 直譯器可以包括了 PyPy 或 IronPython 等。而在這些 python 直譯器上執行已被編譯好的靜態計算圖,可以不需要處理讓 CPython 惡名昭彰的 Global Interpreter Lock(GIL),而使模型達到平行化訓練。
  2. 編譯過後的 ScriptModules,可以 export 到後端 server 能讀取的格式,使 serving 能更容易。
  3. 編譯靜態圖後可以讓針對圖中所有的計算元作優化,雖然在編譯時間較為緩慢,但整體而言可以提高執行速度。
  4. 透過編譯後的模型,不再只能由 Python 可以讀取,反而成為與許多其他的程式語言接口。

註釋:

[註一] 在 PyTorch 1.3 為 torch.jit.ScriptModule 物件,而非如 1.2 是 torch.jit.ScriptModule 的 subclass torch.jit.WeakScriptModuleProxy

Refernces:

  1. INTRODUCTION TO TORCHSCRIPT
  2. 新版PyTorch發布!新增TorchScript API,4大功能更新值得關注

上一篇
Day 7 邁向 Tensorflow 2.0 之路
下一篇
Day 9: TorchScript 實例操作
系列文
深度學習裡的冰與火之歌 : Tensorflow vs PyTorch31

尚未有邦友留言

立即登入留言