iT邦幫忙

2022 iThome 鐵人賽

DAY 21
0
自我挑戰組

高中生也可以!利用強化學習讓機器人動起來!系列 第 21

D21:stable_baselines3範例&PPO演算法基本API

  • 分享至 

  • xImage
  •  

昨天不知道各位有沒有更加了解stable_baselines3這個模組了,今天要直接帶大家來看看官方文檔中的一些範例。藉此讓各位對強化訓練有基本的認識,基本上改成自定義環境也只是把環境id改掉而已。其他基本上可以不用更動。

範例

首先匯入所需的模組,然後make_vec_env()是可以一次使用四個一樣的環境,不過此專案使用gym.make()就好了。

接下來會先宣告演算法的使用PPO("MlpPolicy", env, verbose=1),第一個參數為policy,PPO可以使用MlpPolicy, CnnPolicy,前者的觀察值為一個陣列,後者則是直接輸入一張圖片。

宣告好演算法之後就可以讓他學習啦。learn(total_timesteps=25000)代表這個演算法在指定環境中要交互25000次(total_timesteps),不過通常不會那麼少,救我而言都是先跑1萬次或者10萬次看看有沒有甚麼bug之類的需要針對環境微調的部分,若無太大問題的話訓練次數就是大約50萬次起跳,複雜的任務幾百萬次都不在話下。

最後訓練好後要將模型儲存起來,接下來就可以評估模型訓練的情況了。

import gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Parallel environments
env = make_vec_env("CartPole-v1", n_envs=4)

#訓練演算法
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")
del model # remove to demonstrate saving and loading

#評估演算法
model = PPO.load("ppo_cartpole")
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

各位可以先跑跑看這個程式,等他訓練完後會跳出一個視窗。這就是env.render()函數將交互過程給可視化了,此環境訓練過程中似乎無法看到訓練情況,其他環境也都可以試試看。

https://ithelp.ithome.com.tw/upload/images/20220922/20151029u1ytULiCEo.png

PPO參數

近端策略優化(Proximal Policy Optimization, PPO)是一個強化學習演算法,這種演算法可以很好的幫我們完成一些任務。其數學建模跟原理以及演變等各位可以參考底下的網址去了解看看,這邊直接帶各位去了解這個演算法在模組中的一些基本參數。

以下會列幾個常用的參數,在宣告PPO演算法時可以調整,每個演算法官方也都有詳細的說明。有興趣都可以再去看看。前兩個參數為必填參數,其他為可選參數。

第一個參數就是policy:可以使用”MlpPolicy”, “CnnPolicy”等等,MlpPolicy就是觀察值為一串陣列時用的;CnnPolicy則是觀察值為一張圖片的二維陣列時用的。

第二個參數為強化學習環境env

verbose(int):0為無輸出,1為訓練信息,2為debug信息。預設為1。

seed(int):指定亂數種子。在實驗上相當重要,可以使每次訓練中的「隨機」會被確定,使得每次訓練結果都不會不同。

n_steps(int):過了幾個step後要更新一次環境,計算一些如損失等等的訊息。

--我比較常用的參數就是以上幾個,接下來也會介紹其他的參數。

learning_rate(float):學習率,定義區間為0~1。

batch_size(int):小批量大小。

n_epochs (int) :優化代理人損失時的 epoch 數

其他有一些為演算法中影響數學運算的超參數:

gamma(float):折扣係數

clip_range (float):剪輯參數,它可以是目前剩餘進度的數值(從1到0)。

clip_range_vf (float):值函數的剪輯參數,它可以是目前剩餘進度的數值(從1到0)。預設是None如果 則不會對值函數進行剪輯。此剪輯取決於獎勵縮放!

ent_coef (float):計算損失函數時的熵係數

vf_coef (float):計算損失時的價值函數係數

max_grad_norm (float):梯度裁剪的最大值

以上為超參數的介紹,不過還有一些我也沒用過的就不拿來介紹了。接下來來看看演算法模型可使用的方法:

learn(total_timesteps:int):訓練模型,訓練次數為total_timesteps的次數。
save("model name"):儲存訓練好的模型。
load("model name"):載入訓練好的模型。

model.predict(obs):輸入觀察值後會生成action值跟states值,不過通常只會使用到action值。

結語

花了兩天來介紹stable-baselines3,希望各位都可以藉此更加認識這個模組,當然還有很多內容跟演算法沒有介紹到,如果各位有興趣的話可以再去參考官方文檔上的說明。明天開始要來訓練我們自己寫的環境了。等那麼久終於要開花結果了,那各位我們就明天見囉!

參考資料、網址

stable-baselines3 PPO介紹
https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
PPO介紹
https://blog.csdn.net/wangwei19871103/article/details/100398786


上一篇
D20:強化學習模組—stable-baselines3介紹
下一篇
D22:使用強化學習訓練自己的環境
系列文
高中生也可以!利用強化學習讓機器人動起來!30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言