昨天不知道各位有沒有更加了解stable_baselines3這個模組了,今天要直接帶大家來看看官方文檔中的一些範例。藉此讓各位對強化訓練有基本的認識,基本上改成自定義環境也只是把環境id改掉而已。其他基本上可以不用更動。
首先匯入所需的模組,然後make_vec_env()
是可以一次使用四個一樣的環境,不過此專案使用gym.make()
就好了。
接下來會先宣告演算法的使用PPO("MlpPolicy", env, verbose=1)
,第一個參數為policy,PPO可以使用MlpPolicy, CnnPolicy,前者的觀察值為一個陣列,後者則是直接輸入一張圖片。
宣告好演算法之後就可以讓他學習啦。learn(total_timesteps=25000)
代表這個演算法在指定環境中要交互25000次(total_timesteps
),不過通常不會那麼少,救我而言都是先跑1萬次或者10萬次看看有沒有甚麼bug之類的需要針對環境微調的部分,若無太大問題的話訓練次數就是大約50萬次起跳,複雜的任務幾百萬次都不在話下。
最後訓練好後要將模型儲存起來,接下來就可以評估模型訓練的情況了。
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
env = make_vec_env("CartPole-v1", n_envs=4)
#訓練演算法
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")
del model # remove to demonstrate saving and loading
#評估演算法
model = PPO.load("ppo_cartpole")
obs = env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
各位可以先跑跑看這個程式,等他訓練完後會跳出一個視窗。這就是env.render()
函數將交互過程給可視化了,此環境訓練過程中似乎無法看到訓練情況,其他環境也都可以試試看。
近端策略優化(Proximal Policy Optimization, PPO)是一個強化學習演算法,這種演算法可以很好的幫我們完成一些任務。其數學建模跟原理以及演變等各位可以參考底下的網址去了解看看,這邊直接帶各位去了解這個演算法在模組中的一些基本參數。
以下會列幾個常用的參數,在宣告PPO演算法時可以調整,每個演算法官方也都有詳細的說明。有興趣都可以再去看看。前兩個參數為必填參數,其他為可選參數。
第一個參數就是policy:可以使用”MlpPolicy”, “CnnPolicy”
等等,MlpPolicy就是觀察值為一串陣列時用的;CnnPolicy則是觀察值為一張圖片的二維陣列時用的。
第二個參數為強化學習環境env
verbose(int):0為無輸出,1為訓練信息,2為debug信息。預設為1。
seed(int):指定亂數種子。在實驗上相當重要,可以使每次訓練中的「隨機」會被確定,使得每次訓練結果都不會不同。
n_steps(int):過了幾個step後要更新一次環境,計算一些如損失等等的訊息。
--我比較常用的參數就是以上幾個,接下來也會介紹其他的參數。
learning_rate(float):學習率,定義區間為0~1。
batch_size(int):小批量大小。
n_epochs (int) :優化代理人損失時的 epoch 數
其他有一些為演算法中影響數學運算的超參數:
gamma(float):折扣係數
clip_range (float):剪輯參數,它可以是目前剩餘進度的數值(從1到0)。
clip_range_vf (float):值函數的剪輯參數,它可以是目前剩餘進度的數值(從1到0)。預設是None如果 則不會對值函數進行剪輯。此剪輯取決於獎勵縮放!
ent_coef (float):計算損失函數時的熵係數
vf_coef (float):計算損失時的價值函數係數
max_grad_norm (float):梯度裁剪的最大值
以上為超參數的介紹,不過還有一些我也沒用過的就不拿來介紹了。接下來來看看演算法模型可使用的方法:
learn(total_timesteps:int):訓練模型,訓練次數為total_timesteps的次數。
save("model name"):儲存訓練好的模型。
load("model name"):載入訓練好的模型。
model.predict(obs):輸入觀察值後會生成action值跟states值,不過通常只會使用到action值。
花了兩天來介紹stable-baselines3,希望各位都可以藉此更加認識這個模組,當然還有很多內容跟演算法沒有介紹到,如果各位有興趣的話可以再去參考官方文檔上的說明。明天開始要來訓練我們自己寫的環境了。等那麼久終於要開花結果了,那各位我們就明天見囉!
stable-baselines3 PPO介紹
https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
PPO介紹
https://blog.csdn.net/wangwei19871103/article/details/100398786