大家好,我是毛毛。
今天是Day 24
昨天完成了CartPole神經網路的部份了,今天實際跑跑看~ ヽ(✿゚▽゚)ノ
def load_weights(self):
self.evalNet.load_weights("dqn.weights.h5", by_name=True)
這個是今天想到補上的,這樣就可以將之前訓練保留的權重,在讀進來ヽ(✿゚▽゚)ノ
def get_weights(self):
return self.evalNet.get_weights()
這個是我用來看模型權重的內容,目的是要看看有沒有修改到權重
import gym
import os
env = gym.make('CartPole-v0')
model = DQN(num_actions=2, num_features=4, e_greedy=0.99, e_greedy_increment=0.002)
model.summary()
print("Old weights: ", model.get_weights())
# 要檢查的檔案路徑
filepath = "./dqn.weights.h5"
# 檢查檔案是否存在
if os.path.isfile(filepath):
print("Exist")
model.load_weights()
else:
print("Not exist")
print("Revised weights: ", model.get_weights())
obs = env.reset()
gym.make()
:建立CartPole-v0的環境env.reset()
:讓環境在一開始初始化,還有在遊戲結束的時候重置環境step = 0
for times in range(4000):
obs = env.reset()
total_reward = 0
while True:
env.render()
action = model.choose_action(obs[None, :])
obs_, reward, terminal, _ = env.step(action)
total_reward += reward
if total_reward < 200:
model.store_transition(obs, action, reward, obs_, terminal)
else:
print('Nice!! 。:.゚ヽ(*´∀`)ノ゚.:。')
if (step > 1000) and (step % 3 == 0):
model.learn()
obs = obs_
if terminal:
env.render()
#print(reward, terminal)
break
step += 1
print('Time: ', times, 'Reward: ', total_reward, 'Steps: ', step)
if times%200 == 199:
model.evalNet.save_weights('./dqn.weights.h5')
env.close()
env.render()
:畫面視覺化呈現,但是它只會出現呼叫當下的畫面,所以如果要持續出現,就需要寫個迴圈env.step()
:執行動作,輸入0或1的值,數字分別代表左(0)、右(1)。
執行結果
剛建完的模型的權重
讀取之前訓練的模型的權重
Console的執行結果
今天就到這啦~ 0(:3 )~ ('、3_ヽ)_
大家明天見
Reference