Hi everyone, I'm 毛毛.
Today is Day 27.
This time it's Acrobot~ ヽ(✿゚▽゚)ノ
The image above is the Acrobot-v1 environment from Gym.
```python
import gym
import random
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
```
Import the required packages.
Nothing much to explain here; it's very similar to the previous posts.
```python
def __init__(self, s_size, a_size, e = 0.01, g = 0.95, memory_size = 2000, learning_rate = 0.001):
    self.s_size = s_size                        # state dimension
    self.a_size = a_size                        # number of discrete actions
    self.e = e                                  # epsilon for epsilon-greedy exploration
    self.e_decay = 0.99                         # epsilon decay factor
    self.g = g                                  # discount factor gamma
    self.memory = deque(maxlen = memory_size)   # replay buffer
    self.learning_rate = learning_rate
    self.eval_model = self.__build_model()
    self.target_model = self.__build_model()
```
As before, this sets up the hyperparameters for the neural network.
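For reference, here is a minimal sketch of how this constructor could be wired to Acrobot-v1. The class name `DQN` is an assumption on my part, since the post shows the methods without the surrounding class definition:

```python
import gym

# Minimal sketch: connect the agent to Acrobot-v1.
env = gym.make('Acrobot-v1')
s_size = env.observation_space.shape[0]   # 6 state features
a_size = env.action_space.n               # 3 discrete actions (torque -1 / 0 / +1)

agent = DQN(s_size, a_size)               # "DQN" is a hypothetical class name wrapping the methods in this post
```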
```python
def __build_model(self):
    model = Sequential()
    model.add(Dense(32, input_dim = self.s_size, activation = 'relu'))
    model.add(Dense(64, activation = 'relu'))
    model.add(Dense(32, activation = 'relu'))
    model.add(Dense(self.a_size, activation = 'linear'))   # one Q-value per action
    model.compile(loss = "mean_squared_error", optimizer = Adam(lr = self.learning_rate))
    return model
```
Here we build the eval_net and the target_net, and use the compile function to define the loss function (mean squared error) and the optimizer (Adam); no extra metrics are passed in this compile call.
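Because the output layer is linear, the network returns one raw Q-value per action rather than probabilities. A quick shape check, assuming the agent instance sketched above (so s_size = 6 and a_size = 3):

```python
import numpy as np

dummy_state = np.zeros((1, 6))                     # a batch containing one state
q_values = agent.eval_model.predict(dummy_state)   # shape (1, 3): one Q-value per action
print(q_values)
```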
```python
def target_replacement(self):
    self.target_model.set_weights(self.eval_model.get_weights())
```
Update the target_net's weights with the eval_net's weights.
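How often to copy the weights is decided in the training loop (which comes in tomorrow's post), so the interval below is purely an assumption; this is just a sketch of the usual pattern:

```python
REPLACE_INTERVAL = 100   # hypothetical sync frequency (in environment steps)

total_steps = 0
# ... inside the environment-interaction loop ...
total_steps += 1
if total_steps % REPLACE_INTERVAL == 0:
    agent.target_replacement()   # target_net catches up with eval_net
```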
```python
def store_transition(self, state, action, reward, new_state, terminal):
    self.memory.append((state, action, reward, new_state, terminal))
```
Store a transition (experience) in the replay memory.
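One detail worth keeping in mind: replay_transition below calls predict directly on the stored states, so each state should already be a (1, s_size) batch when it goes into memory. A sketch of the interaction step under that assumption, using the old Gym API that this post's imports imply:

```python
state = np.reshape(env.reset(), (1, agent.s_size))   # shape (1, 6) for Acrobot-v1

action = agent.choose_action(state)
new_state, reward, terminal, _ = env.step(action)
new_state = np.reshape(new_state, (1, agent.s_size))

agent.store_transition(state, action, reward, new_state, terminal)
state = new_state
```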
```python
def replay_transition(self, batch_size):
    minibatch = random.sample(self.memory, batch_size)
    for state, action, reward, new_state, terminal in minibatch:
        target = self.eval_model.predict(state)
        if terminal:
            target[0][action] = reward
        else:
            # Double DQN target: eval_net picks the next action, target_net scores it
            a = self.eval_model.predict(new_state)[0]
            t = self.target_model.predict(new_state)[0]
            target[0][action] = reward + self.g * t[np.argmax(a)]
        self.eval_model.fit(state, target, epochs = 1, verbose = 0)
```
Randomly sample past experiences from memory to train the model.
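Note that this is actually the Double DQN target rather than the vanilla DQN one: eval_net chooses the best action for the next state, while target_net supplies that action's value. With gamma = self.g, the target for the chosen action is:

$$
y = \begin{cases}
r & \text{if terminal}\\
r + \gamma \, Q_{\text{target}}\!\big(s',\ \arg\max_{a'} Q_{\text{eval}}(s', a')\big) & \text{otherwise}
\end{cases}
$$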
```python
def choose_action(self, state):
    if random.random() < self.e:
        return random.choice(range(self.a_size))
    act = self.eval_model.predict(state)
    return np.argmax(act[0])
```
Choose an action with an epsilon-greedy policy: with probability e a random action is taken, otherwise the action with the highest predicted Q-value.
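One observation: self.e_decay is set in __init__ but never used in the methods shown, and the default e = 0.01 is already at the low end. If the intention is to start with more exploration and anneal it over time, a common pattern (my assumption, not something in the original code) is to pass in a larger epsilon and decay it after each replay:

```python
EPSILON_MIN = 0.01                     # hypothetical floor for exploration

agent = DQN(s_size, a_size, e = 1.0)   # start fully exploratory (assumed starting value)

# e.g. at the end of each replay_transition() call:
if agent.e > EPSILON_MIN:
    agent.e *= agent.e_decay
```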
```python
def save_weights(self):
    self.eval_model.save_weights('./dqn_weights.h5')
```
Save the eval_net's weights.
```python
def load_weights(self):
    self.eval_model.load_weights('./dqn_weights.h5')
    self.target_model.load_weights('./dqn_weights.h5')
```
Load the previously trained weights.
That wraps up the DQN part for today; tomorrow I'll post the results of running it on Acrobot 0(:3 )~ ('、3_ヽ)_
See you all tomorrow!