iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 26
0
AI & Data

深度學習裡的冰與火之歌 : Tensorflow vs PyTorch系列 第 26

Day 26: Tensorflow 客製化一個 C++ 運算元

為何需要客製化一個 C++ 運算元?理由通常有下面三點

  1. 現有的運算元透過繼承或擁有,都無法達到你的目標
  2. 現有的運算元透過繼承或擁有可以完成你的目標,但效率不佳
  3. 需要手動將多個已組合的運算元做融合,因為編譯器的自動融合效率可能不佳

除了最後一點是屬於編譯器的範圍,前兩點都屬於程式設計者的範疇,為了能夠實現嶄新的想法,在研究中得到驗證,總是要能重頭寫出一個全新的運算元,如 Capsule Networks 的 routing。

為了能夠將完成的 C++ 運算元應用到 Tensorflow 的框架,我們必須完成下列步驟:

  1. 註冊 C++ 運算元的 speciation。註冊運算元主要是提供一個程式介面給運算元,包括了運算元呼叫的函示名稱,輸入和輸出。同時,也涵蓋張量的維度資訊,以供張量維度推測用。
  2. 使用 C++ 實現運算元的程式邏輯,運算元的 C++ 實現可以被稱為 kernel。一個新的運算元可以有好幾種不同的 kernels,取決於輸入或輸出的型別或計算架構,如在 CPU 或 GPU。
  3. (非必要)為你的 C++ 運算元建立一個 python 的 wrapper ,若你不喜歡註冊運算元所給的預設 wrapper。這個 wrapper 會提供 C++ 運算元一個 python 介面,讓 python 使用者可以呼叫。
  4. (非必要)寫一個函式來計算你的梯度
  5. 測試你的運算元:測試可以在 python 完成,也可以在 C++ 完成。如果你寫了計算梯度的函示,並需要在 python 內測試,則可以使用 tf.test.compute_gradient_error 來診斷你的梯度計算函示。

在今天的文章中,我們將會跟隨官方文件建立一個簡單但基本的運算元,稱為 Zero Out。這個 Zero Out 運算元的輸入是一個型別為 int32 的張量,而輸出則會先拷貝輸入,並設所有的元素為零,除了第一個元素。
另外,在官方文件中還有許多 Advanced 特徵可以放進客製化的運算元中,今天都不會涵蓋。那麼,現在讓我們開始:

註冊運算元

在註冊的部分是利用 REGISTER_OP macro 來完成。在 SetShapeFn 的地方,也就是設立用於 Shape Inference 的函示,則使用了 C++ lambda,傳入一個 InferenceContext* 物件指摽,該物件會取出輸入,並將輸出的維度設成與輸入相同。

//zero_out.cc
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

//定義運算元的介面
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 
      c->set_output(0, c->input(0));
      return Status::OK();
    });

實現 kernel

在這個步驟,我們將會完成以下事項:

  1. 建立一個新的類別,該類別會繼承 OpKernel
  2. 覆寫 compute 方法:compute 方法只有一個引數,那就是傳入一個 OpKernelContext* 物件指標,可以藉由這個物件指標去提取其他有用的資訊。

原始碼如下,在原檔案中增加即可

//zero_out.cc
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

編譯運算元為 shared library

在這裡我們沿襲上一篇編譯的工具,也就是 shared library,將檔案放在原始碼的 source tree,並透過編譯整個 source tree 來完成編譯。編譯的工具一樣,使用 bazel,那麼相對應的 BUILD 內容則在下方:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

最後則呼叫 bazel 命令來完成編譯。附檔名為 so 適用於 Linux,而 dylib 適用於 Linux。

$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so

在 Python 內使用完成客製化的運算元

編譯完成後,我們可以在 python 直譯器中直接呼叫剛剛編寫的運算元。

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
with tf.Session(''):
  zero_out_module.zero_out([[1, 2], [3, 4]]).eval()

# Prints
array([[1, 0], [0, 0]], dtype=int32)

上一篇
Day 25: Tensorflow C++ front-end API
下一篇
Day 27: 再造訪 ONNX 和它的 Python API
系列文
深度學習裡的冰與火之歌 : Tensorflow vs PyTorch31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言