iT邦幫忙

第 12 屆 iT 邦幫忙鐵人賽

DAY 27
0
AI & Data

Knock Knock! Deep Learning系列 第 27

Day 26 / DL x RL / 小試身手的 Project Talk —— DeepCueLearning

上一篇我們用 CartPole 了解如何實作 RL。今天來介紹自己的 RL project,更加深實作概念吧。

RL 非常適合應用在訓練遊戲 AI,而我在 Stanford CS229 的 project 正是利用 RL 訓練 AI 玩撞球遊戲。總共用到了三種 model,過程中也感受到 RL 真的不太容易訓練好。

接著我們就來看看這個 project,並認識不同的 RL algorithm 吧。

程式碼、report、poster 都在 GitHub: pyliaorachel/CS229-pool

Problem

目標很簡單,就是讓我們的 model 學會打撞球。

pool game
—— Pool game。[1]

具體來說的 formulation 如下:

  • Agent:DeepCueLearning AI。
  • Environment:撞球模擬環境。我們用了這個 open source 撞球遊戲作為訓練環境。
  • State:每個球在撞球桌上的 (x, y) 位置。
  • Action:擊球角度(0 - 360)和力道(0 - 1)。
  • Reward:球進袋 +5,白球沒有擊中球 -1,其他則 0。

Approach

我們介紹三種 RL algorithm 來訓練我們的 AI。

Q-Table

Q-table 在 Day 25 介紹過了。簡單來說,Q-table 是用 lookup table 來 approximate Q-value function,並用 Q-learning 訓練的一個方法。

Q-learning update rule 再看一次:

https://ithelp.ithome.com.tw/upload/images/20201011/20130687N5bYuluXCL.png

跟前面提到的不同的是,在這個任務裡,我們的 state 和 action 都是連續值,兩個都要 discretize 才能建立一個 lookup table。

實作中,我們將 state 的 x, y 各 discretize 成 50 個區間,擊球角度 discretize 成 18 個區間,力道 discretize 成 5 個區間。這樣會有 50 x 50 個 state 和 18 x 5 個 action,lookup table 大小為 50 x 50 x 18 x 5。

Deep Q-Network (DQN)

DQN 前幾篇介紹 Atari AI 有介紹過,簡單來說是用 neural network 來 approximate Q-value function。這樣就不用像 Q-table 一樣必需 discretize input 了。

Asynchronous Advantage Actor-Critic (A3C)

Asynchronous Advantage Actor-Critic (A3C) 有個很讓人摸不著頭緒的名字,但原理並不難。

在 RL 可以訓練兩種 network,一種是 policy network,input state 預測每個 action 的機率。另一種是 value network,input state 預測 state value 或 Q-value。而 A3C 結合了兩種方式,訓練一個 network 預測 action 機率以及 state value。

Actor-Critic 指的就是讓 policy network 當 actor 負責選擇 action,value network 當 critic 負責估計 state 的好壞,而 actor 會根據 critic 給的 value 來更新 model。

Advantage 指的是 actor 在 state s 做出 action a 的優勢。具體來說,大概可以用做出 action a 獲得的 reward R 及 critic 給的 value V(s) 之間的差距 A = R - V(s) 來表示。而 A3C 並不像一般 model 用 discounted reward 而是用 advantage 來訓練。也就是說,比起告訴他你選的 action 有多好,告訴他你選的 action 比預期來得好多少,會讓訓練更有效率,因為這樣 model 更可以知道自己的不足。

最後 asynchronous 指的是我們讓很多 worker 同時跟環境互動汲取經驗學習,比起只有一個 worker 更有效率。具體來說,A3C 會有一個 global network,還有很多 worker 同時跟環境互動,並且把經驗 (state, action, reward) 回報給 global network。大概收集足夠的量,global network 就會根據這些經驗進行學習。如此一來可以有效汲取大量且多樣的經驗,network 也能訓練得更好。

A3C 的 action output 可以是 continuous 或 discrete。Continuous 的話我們會預測一個 normal distribution 的 mean 和 variance,而 discrete 的話就跟 classification 一樣預測每個 discrete action 的機率。我們兩種都有實作做比較。

Results & Analysis

我們總共有五個 method,Q-table、DQN、A3C with continuous action、A3C with discrete action 四個主要 algorithm,以及 random action 作為 baseline。

我們分別做了兩顆球和四顆球的實驗,其中一顆是白球。

兩顆球實驗中,訓練了 1000 episode 結果如下:

2 ball average reward
—— 兩顆球實驗中,五個 method 的平均 reward 趨勢。

2 ball result table
—— 兩顆球實驗中,五個 method 的平均 reward、訓練時間和空間。

首先可以看到的是 Q-table 成果最好。我們發現 Q-table 學會用六下打進某個洞,所以效果不算太差。但缺點就是訓練時間太長,以及 lookup table 佔用太多空間了,而這只是兩顆球的實驗。因此 Q-table 不太適合 generalize 到更複雜的設定中。

DQN 和 A3C 效果滿普通的,訓練也不太穩定。A3C continuous 一開始還學得比隨便亂打還差,肯定是中途學壞了,後來可能跳脫 local minimum 才逐漸有起色。DQN 和 A3C discrete 都是預測 discrete action,類似於做 classification,我們也發現這種訓練比預測 normal distribution 穩定多了。而且 neural network 的 discritization 也能切得比 Q-table 細,所以會推薦大家要玩 RL 的話先從訓練 discrete action output 開始。

A3C 的訓練時間和佔用空間都非常有效率,其實可以訓練更久試試看,效果應該不錯,只是當時沒有時間再訓練下去了。此外,exploration 的設定可以再提高一些,讓 model 嘗試更多種打法,比較不容易被困在 local minimum。

我們其實還有四顆球的實驗結果,不過兩顆球就夠糟了,四顆球的結果也不太有意義。有興趣可以去看 report。

Code

最後一樣簡單看一下重點部位的 code。我們看一下 A3C with discrete action 的就好。

首先是 global network:

class Net(nn.Module):
    def __init__(self, s_dim, a_dim, h_dim):
        super().__init__()

        self.s_dim = s_dim
        self.a_dim = a_dim

        # Actor
        self.a1 = nn.Linear(s_dim, h_dim)
        self.a21 = nn.Linear(h_dim, a_dim[0])
        self.a22 = nn.Linear(h_dim, a_dim[1])

        # Critic
        self.c1 = nn.Linear(s_dim, h_dim)
        self.v = nn.Linear(h_dim, 1)

    def forward(self, x):
        a1 = F.relu(self.a1(x))
        logits1 = self.a21(a1)
        logits2 = self.a22(a1)
        c1 = F.relu(self.c1(x))
        values = self.v(c1)

        return logits1, logits2, values

Actor 和 critic 都是簡單的 feed-forward network。Actor 預測兩種 action,擊球角度和力道。

再來是每個 worker 和環境互動:

class Worker(mp.Process):
    def __init__(self, gnet, opt, global_ep, global_ep_r, env_params, hidden_dim, episodes, episode_length, model_path=None):
        super().__init__()

        self.env_params = env_params # game environment setting
        self.hidden_dim = hidden_dim
        self.gnet = gnet # global net
        self.opt = opt # optimizer
        
        self.episodes = episodes # total episodes
        self.episode_length = episode_length # number of timesteps in an episode
        self.g_ep = global_ep # total episodes so far across all workers
        self.g_ep_r = global_ep_r # total average rewards so far across all workers

        self.gamma = 0.8 # reward discount factor
        
    def run(self):
        # set up game environment
        env = PoolEnv(**self.env_params)
        
        # worker network
        self.lnet = Net(env.state_space.n, env.action_space.n, self.hidden_dim)

        # start gathering experience
        total_steps = 1
        while self.g_ep.value < self.episodes:
            next_state = env.reset() # initialize state
            state_buffer, action_buffer, reward_buffer = [], [], [] # keep experience in buffers
            rewards = 0 # accumulate rewards for each episode
            done = False
            
            # start episode
            for t in range(self.episode_length):
                state = norm_state(next_state, env.state_space.w, env.state_space.h) # normalize state values
                # choose action and do action
                action = self.lnet.choose_action(state)
                next_state, reward, done = env.step(action)
                rewards += reward
                done = done or t == self.episode_length - 1

                # store experience in buffers
                action_buffer.append(action)
                state_buffer.append(state)
                reward_buffer.append(norm(reward, env.max_reward, env.min_reward))

                # train
                if total_steps % GLOBAL_UPDATE_RATE == 0 or done:
                    # push experience, update global network, clone back global network
                    push_and_pull(self.opt, self.lnet, self.gnet, done, next_state, state_buffer, action_buffer, reward_buffer, self.gamma)
                    state_buffer, action_buffer, reward_buffer = [], [], []

                # transition to next state
                state = next_state
                total_steps += 1

                if done:
                    # end of episode, update global information
                    record(self.g_ep, self.g_ep_r, rewards)
                    break

首先每個 worker 都繼承 mp.Processmp 是 PyTorch 的 multiprocessing package,利用它來為每個 worker 建立不同 thread 讓他們能 run in parallel。

每個 worker 的架構都跟 global network 相同。run 的時候開始和環境互動取得經驗,每 GLOBAL_UPDATE_RATE 個 step 就會把經驗 push 回 global network,訓練之後,再把 global network 複製回來。可以想像有很多 worker 都在同時做這件事,增加效率。

最後是 training:

def train(env_params,episodes=200, episode_length=50):
    # Global network
    env = PoolEnv(**env_params)
    gnet = Net(env.state_space.n, env.action_space.n, HIDDEN_DIM)
    opt = SharedAdam(gnet.parameters(), lr=LR) # global optimizer
    global_ep, global_ep_r = mp.Value('i', 0), mp.Value('d', 0.) # 'i': int, 'd': double

    # Parallel training
    workers = [Worker(gnet, opt, global_ep, global_ep_r, i, env_params, ACTION_BUCKETS, HIDDEN_DIM, episodes, episode_length, model_path)
               for i in range(mp.cpu_count() // 2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

SharedAdam 是特別進去原本的 Adam optimizer 讓裡面的 state 可以有效共享在不同 worker 中。簡單來說,我們建立一個 global network,以及一群 worker thread 讓他們同時跑在不同 thread 上,每個 worker 就像上面看到的一樣,收集完一些經驗就丟給 global network 訓練。

結語

這算是我第二個 RL project,做完心得就是 RL 雖然好像很萬用很強大,但訓練起來真的需要一點經驗才能有好的成果。如果大家想要玩 RL,推薦先從簡單的開始建立經驗吧!

參考資料

  1. Pool game in Python
  2. (Mnih et al., 2016) Asynchronous Methods for Deep Reinforcement Learning
  3. Simple Reinforcement Learning with Tensorflow Part 8: Asynchronous Actor-Critic Agents (A3C)

上一篇
Day 25 / DL x RL / Hello Reinforcement Learning —— CartPole
下一篇
Day 27 / DL x RL / 令世界驚艷的 AlphaGo
系列文
Knock Knock! Deep Learning31

尚未有邦友留言

立即登入留言