iT邦幫忙

2022 iThome 鐵人賽

DAY 12
0
AI & Data

30天AI馴獸師之生存日記系列 第 12

【Day 12】RL | HW12

  • 分享至 

  • xImage
  •  

a) Choose one algorithm from REINFORCE with baseline、Q Actor-Critic、A2C, A3C or other advance RL algorithms and implement it.

b) Please explain the difference between your implementation and Policy Gradient

c) Please describe your implementation explicitly (If TAs can’t understand your description, we will check your code directly.
(a)我實作 DQN 算法
(b)Please explain the difference between your implementation and Policy Gradient
在 DQN 的算法中,比 policy gradient 多實作了 Replay experience memory 的機制,就像是 Q learning 一樣,這邊 replay memory 扮演記憶池的功能,在訓練的過程中隨機加入之前的經驗,讓 policy model 學習;另外在 DQN 的 agent 中實作兩個一模一樣的 model,main_q_network 和 target_q_network,其中 main_q_network 就像是 policy gradient 裡面的 policy model 一樣,使用的是最新的參數,而 target_q_network 則是在迭代一定次數後,有實作一個 synchronized 的 function,將 main network 的參數複製給 target q network。
(c)
這次實作中,最重要的 module 有以下三個:

  1. Policy model
    這個 model 會吃進 state,然後學習預測下一步的 action 是什麼。
    使用三層 fully connected layers,並使用 relu 當作 activation function。
  2. Replay Experience Memory
    在一定的 capacity 之下,儲存 'state', 'action', 'next_state', ‘reward' 這些 transition 的資訊。並實作 push() 和 sample() function。前者是把過去的 experience 儲存在 memory buffer 中,後者則是從 memory buffer output batch size 大小的 experience ,並回傳。
  3. Agent
    在 initialization 中,會先初始化 memory buffer,以及兩個一模一樣的 policy model,分別是 main_q_nn 和 target_q_nn ,並且我使用 RMSprop 當作 policy model 的 optimizer。
    接下來就實作 7 個不同的 function,分別是 get_action() ,就是輸入 state 給 main_q_nn 這個 model 來得到接下來要做的 action 預測,或者是 random 回傳一些 action。決定的條件是看 epsilon 這個值,這個值得計算方式是:EPS_END + (EPS_START - EPS_END) * np.exp(-1. * steps_done / EPS_DECAY)。
    memorize()這個 function 是用來把現在的 experience 塞到 memory buffer。
    make_minibatch() 則是取得 memory 中的一些過往經驗後打包成 batch 回傳。
    update_q_function() 則主要是要呼叫 get_expected_state_action_values() 和 update_main_q_nn() 這兩個 function。
    update_main_q_nn() 這個 function 主要是用smooth_l1_loss計算loss,並 update main_q_nn model。
    update_target_q_function() 則是會在跑完一個 batch 之後,直接把當時 main_q_nn 的 weight 直接複製到 target_q_nn。

上一篇
【Day 11】Explainable AI | HW 9
下一篇
【Day 13】Life-Long learning 1
系列文
30天AI馴獸師之生存日記15
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言