在ONNX 的 python API 文章中提到了針對演算法的不同,ONNX 提供了兩個不同的運算元集:定義類神經網路相關的運算元被稱為 ONNX,而與機械學習相關的被稱為 ONNX-ML。我們已經利用 super-resolution 模型做轉換的 demo, 今天我們會專注在 ONNX-ML (或 scikit-learn)和在 ONNX 不同版本轉換。
第一個例子會使用 sklearn-onnx package 來轉換 scikit-learn 模型為 ONNX 模型。
為了讓例子能夠順利的執行,除了 scikit-learn 你需要安裝 skl2onnx 和 onnxruntime,使用 pip install
命令安裝即可。 版本如下:
import sklearn
import skl2onnx
import onnxruntime
print('sklearn', sklearn.__version__)
print('skl2onnx', skl2onnx.__version__)
print('onnxruntime', onnxruntime.__version__)
# sklearn 0.21.2
# skl2onnx 1.5.2
# onnxruntime 0.5.0
首先我們得來建立一個 scikit-learn 模型。這個模型用的是 RandomForest 演算法建構的分類器,關於 skl2onnx 的 scikit-learn 模型,則可以見 skl2onnx的官方文件。資料則是 iris 的資料,主要是使用一些花的特徵如尺寸等作為分類特徵,屬於小量的資料集。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
X, y = iris.data, iris.target
# 將資料集分為訓練和測試
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)
# => RandomForestClassifier(bootstrap=True, class_weight=None,
# criterion='gini', max_depth=None,
# max_features='auto', max_leaf_nodes=None,
# min_impurity_decrease=0.0, min_impurity_split=None,
# min_samples_leaf=1, min_samples_split=2,
# min_weight_fraction_leaf=0.0, n_estimators=10,
# n_jobs=None, oob_score=False, random_state=None,
# verbose=0, warm_start=False)
接著我們就要使用 skl2onnx 裡的 ONNXMLTools 來將模型轉成 ONNX 格式。 ONNXMLTools 是一個轉換不同機械學習的格式到 ONNX 格式的一個工具。目前可轉換的的機械學習架構包括了: Keras, Tensorflow, Core ML(Apple),scikit-learn, Spark ML (還在實驗階段),LightGBM, libsvm 以及 XGBoost。除了 scikit-learn 的 skl2onnx , Keras 和 Tensorflow 都有提供一個 ONNXMLTools 的 wrapper,分別是 keras2onnx 和 tf2onnx。
為了能從 scikit-learn 模型轉換為 ONNX 格式,我們必須呼叫 skl2onnx.convert_sklearn
函式,根據官方API,該函式的第一個引數必須是一個 scikit-learn 模型,在本例子中為 RandomForestClassifier。在例子中還傳入了關鍵字引數 initial_types
,這個參數接受一個 python list,每一個元素都是一個 tuple,這個 tuple 必須包括變數名稱(字串)和任何定義在 data_types.py
的資料型態。
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# 建立一個變數名為'float_input' 的物件,物件資料型態為 FloatTensorType
# 變數的值不需要傳入,但是必須傳入輸入的維度(batch 的維度以 1 代替)
initial_type = [('float_input', FloatTensorType([1, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
f.write(onx.SerializeToString())
最後,我們則使用 ONNX Runtime 來執行模型。和 Tensorflow 很相像,ONNX Runtime 會建立一個 InferenceSession
Session 物件,並從剛才我們利用物件的 SerializeToString()
方法儲存的 rf_iris.onnx 檔案,載入模型。關於 ONNX Runtime 更詳細的說明擇留到下一篇文章。目前只用 ONNX Runtime 來測試轉換的結果。
import onnxruntime as rt
import numpy
sess = rt.InferenceSession("rf_iris.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
skl2onnx.convert_sklearn
還有一個關鍵字引數,沒有在上面的例子中秀出,那就是 target_opset
。target_opset
可以接受欲檢查的 Opset version,若沒給予skl2onnx.convert_sklearn
則會使用讀入的版本來做檢查。
但倒底什麼是 Opset version,就讓我們從頭講起...
ONNX 格式總共有三種 versioning 的方式:分別是 IR version,Opset version (運算元集合的版本)和 Model version。這三個 versioning 的架構是互相獨立,若以開發的速率來檢查三個 versioning 的改變速率,Opset 會快於 IR versioning。要查詢使用 IR 和 Opset 的版本,一方面可以印出 ModelProto
物件,其 ir_version 欄位會定義 IR version,而 opset_import 欄位會定義 Opset version。Model version 雖然沒有列出,但可以透過 ModelProto.model_version
讀取 Model version。大家可以回顧一下再造訪 ONNX Graph 這篇文章。
Model version 遇到改變模型輸入和輸出的情況,而造成原本行為無法執行時,則需要增加版本。同樣的,IR version 遇到 ONNX 的 speciation 改變造成 protobuf 的定義改變的情況,而造成原本函式庫讀入寫入模型行為無法執行,則需要增加版本。而任何修改運算元的行為造成原本運算元的行為改變,都需要增加 Opset 版本。
因為 IR Versioning 通常比 Opset versioning 更為 stable,所以下面的例子中,只會提到 Opset versioning。除此之外做格式轉換的 package,都會將檢查 Opset versioning 列為特別檢查項目。
Opset versioning 其實是指一組運算元的版本架構,每一個運算元都有自己的版本。當運算元進入或移除原有的集合都會造成 OpSet version 增加,不過也有可能每一個運算元在不同的 Opset version 仍維持相同的版本。如果要用程式的方式來提取 OpSet version,則可以呼叫 model_def.opset_import[0].version
。其中 model_def.opset_import
會傳回一個 python list,裡面只有一個元素為 onnx.OperatorSetIdProto
物件,具有 version
屬性。ONNXMLTools 的 converter,通常都會逐一檢查每一個運算元的版本,一但給定了所有運算元版本後再行計算最佳的 Opset version。
一個運算元,基本上可以藉由 domain, op_type, 和 op_version 三個欄位構成獨一無二的運算元 ID。domain 指的是模型的命名空間,通常以實現該運算元的架構 reverse-DNS 網域名稱來命名,如 'com.facebook' 是 PyTorch 一個 domain name。若由 ONNX 建立的計算圖,其 domain name 為 org.onnx。若建立模型時,domain 是空字串,則 domain 也為預設為 onnx。
若有重度使用 scikit-learn 的讀者有需要為自己客製化的 Estimator 或 Transformer 建立 ONNX converter,則可以參考 sklearn-onnx的Write your own converter for your own model文章,礙於篇幅,此處不再介紹。