iT邦幫忙

2021 iThome 鐵人賽

DAY 21
0

上一篇說明了Auto ML的基本概念, 本篇我們就來使用auto sklearn實作看看 Auto ML怎麼操作.

請到github下載 notebook

  • notebook: cardiovascular_disease_prediction_notebook_automl.ipynb
  • dateset: cardio_train.csv

我們接續上一個心血管疾病範例改寫, 原本是使用XGBoost演算法, 這次我們改用auto skleran來進行訓練.

首先我們需要安裝 scikit learn與auto sklearn

pip install -U scikit-learn auto-sklearn

然後我們import所需要的module

import autosklearn
import sklearn.metrics
import autosklearn.classification

把autosklean的版本印出來看看

print('autosklearn: %s' % autosklearn.__version__)

接下來我們要呼叫auto sklearn所提供的演算法物件, 選擇classification的AutoSklearnClassifier物件, 參數說明如下

  • time_left_for_this_task: 對每一種模型最多的訓練時間, 預設值是3600秒, 數值愈高模型的準確度會愈高, 但所需要的時間就比較長. 在我們的demo中時間設定短一點(300秒), 這樣不會等待太久, 同時也能看到執行的效果.
  • per_run_time_limit: 每次執行(run)的最高秒數, 預設值是time_left_for_this_task參數的十分之一(360秒). 在我們的demo中也設定為300秒的十分之一, 也就是30秒
# define search
model = autosklearn.classification.AutoSklearnClassifier(time_left_for_this_task=300, per_run_time_limit=30)

執行訓練. 這裡的訓練資料直接使用心血管疾病資料集所切分好的X(features)與y(target), 然後呼叫fit()就會開始執行.

# perform the search
model.fit(X_train, y_train)

在等待一段時間完成訓練之後, 印出統計資料看一下.

# summarize
print(model.sprint_statistics())

這次的統計結果如下.

  • 有11種演算法被納入評估
  • 有3種演算法成功完成訓練, 有8種演算法在限制的時間無法完成
  • 最好的演算法取得的驗證分數為 0.73, metric是 accuracy
auto-sklearn results:
  Dataset name: 77c66f3c-0fdc-11ec-8906-8b133cd412c7
  Metric: accuracy
  Best validation score: 0.732738
  Number of target algorithm runs: 11
  Number of successful target algorithm runs: 3
  Number of crashed target algorithm runs: 0
  Number of target algorithms that exceeded the time limit: 8
  Number of target algorithms that exceeded the memory limit: 0

前面的統計計表說明有3種演算法完成訓練, 那我們來看一下是哪三種演算法

print(model.leaderboard())

結果顯示最好的演算法是gradient_boosting.

          rank  ensemble_weight               type      cost   duration
model_id                                                               
4            1             0.84  gradient_boosting  0.267262  21.289428
8            2             0.08                mlp  0.354654  27.291203
10           3             0.08                sgd  0.356439   7.165989

最後可以印出每固模型的詳細資料

print(model.show_models())

這次的詳細資料如下:

[(0.840000, SimpleClassificationPipeline({'balancing:strategy': 'weighting', 'classifier:__choice__': 'gradient_boosting', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'one_hot_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'mean', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'no_preprocessing', 'classifier:gradient_boosting:early_stop': 'off', 'classifier:gradient_boosting:l2_regularization': 1.0945814167023392e-10, 'classifier:gradient_boosting:learning_rate': 0.11042628136263043, 'classifier:gradient_boosting:loss': 'auto', 'classifier:gradient_boosting:max_bins': 255, 'classifier:gradient_boosting:max_depth': 'None', 'classifier:gradient_boosting:max_leaf_nodes': 30, 'classifier:gradient_boosting:min_samples_leaf': 22, 'classifier:gradient_boosting:scoring': 'loss', 'classifier:gradient_boosting:tol': 1e-07, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.05141281638752715},
dataset_properties={
  'task': 1,
  'sparse': False,
  'multilabel': False,
  'multiclass': False,
  'target_type': 'classification',
  'signed': False})),
(0.080000, SimpleClassificationPipeline({'balancing:strategy': 'none', 'classifier:__choice__': 'mlp', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'no_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'mean', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'feature_agglomeration', 'classifier:mlp:activation': 'tanh', 'classifier:mlp:alpha': 0.05476322473700896, 'classifier:mlp:batch_size': 'auto', 'classifier:mlp:beta_1': 0.9, 'classifier:mlp:beta_2': 0.999, 'classifier:mlp:early_stopping': 'valid', 'classifier:mlp:epsilon': 1e-08, 'classifier:mlp:hidden_layer_depth': 1, 'classifier:mlp:learning_rate_init': 0.012698439797907473, 'classifier:mlp:n_iter_no_change': 32, 'classifier:mlp:num_nodes_per_layer': 136, 'classifier:mlp:shuffle': 'True', 'classifier:mlp:solver': 'adam', 'classifier:mlp:tol': 0.0001, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.07441872802099897, 'feature_preprocessor:feature_agglomeration:affinity': 'manhattan', 'feature_preprocessor:feature_agglomeration:linkage': 'average', 'feature_preprocessor:feature_agglomeration:n_clusters': 264, 'feature_preprocessor:feature_agglomeration:pooling_func': 'max', 'classifier:mlp:validation_fraction': 0.1},
dataset_properties={
  'task': 1,
  'sparse': False,
  'multilabel': False,
  'multiclass': False,
  'target_type': 'classification',
  'signed': False})),
(0.080000, SimpleClassificationPipeline({'balancing:strategy': 'none', 'classifier:__choice__': 'sgd', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'one_hot_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'most_frequent', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'select_percentile_classification', 'classifier:sgd:alpha': 1.6992296128865824e-07, 'classifier:sgd:average': 'True', 'classifier:sgd:fit_intercept': 'True', 'classifier:sgd:learning_rate': 'optimal', 'classifier:sgd:loss': 'log', 'classifier:sgd:penalty': 'l1', 'classifier:sgd:tol': 1.535384699341134e-05, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.24471105740962484, 'feature_preprocessor:select_percentile_classification:percentile': 39.91903776071659, 'feature_preprocessor:select_percentile_classification:score_func': 'f_classif'},
dataset_properties={
  'task': 1,
  'sparse': False,
  'multilabel': False,
  'multiclass': False,
  'target_type': 'classification',
  'signed': False})),
]

這樣就可以使用由auto sklearn所訓練出來的模型進行部署, 然後提供推論的結果. 但以目前來說還不建議直接將auto ML訓練出來的模型放在正式環境上做部署. 比較建議的方式是先使用auto ML工具產出具有基本準確度的模型與參數, 然後由人工進行參數的調整訓練模型以再次提高模型的準確度, 這樣會是比較好的方式.


上一篇
Auto ML簡介
下一篇
第四個範例-使用好用圖形化介面軟體執行口罩物件辨識
系列文
AI平台初學者工作坊: 從training、tracking到serving30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言