[Day 29] Deep Q-Network — Hands-on Implementation

Yesterday we introduced the Deep Q-Network; today we walk through the TUTORIAL to implement it step by step.

Implementation

Import Library

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

  • gymnasium: creates and manages the Gym environment (the maintained successor of OpenAI Gym).
  • collections.namedtuple: creates named tuples with readable field access.
  • collections.deque: creates a double-ended queue with a fixed maximum length, used here to store the replay memory (a quick demo follows).
  • itertools.count: generates an infinitely incrementing counter, used to step through each episode.
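
Before moving on, here is a minimal standalone sketch (with made-up values and a hypothetical Point tuple) of the two collections helpers that the replay memory relies on:

from collections import namedtuple, deque

# A deque created with maxlen silently drops the oldest item once it is full;
# this is what gives the replay memory its fixed capacity.
buffer = deque([], maxlen=3)
for i in range(5):
    buffer.append(i)
print(list(buffer))   # [2, 3, 4]

# A namedtuple behaves like a tuple but allows readable field access.
Point = namedtuple('Point', ('x', 'y'))
p = Point(1.0, 2.0)
print(p.x, p.y)       # 1.0 2.0
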
env = gym.make("CartPole-v1")

Create the CartPole-v1 environment.
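
To get a feel for the environment, a quick sanity check might look like this (a sketch; the seed and the values in the comments are illustrative):

obs, info = env.reset(seed=0)
print(obs.shape)         # (4,): cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space)  # Discrete(2): push the cart to the left or to the right

# gymnasium's step() returns (observation, reward, terminated, truncated, info)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(reward, terminated, truncated)   # e.g. 1.0 False False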

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

Initialize matplotlib and check whether we are running inside an IPython environment.

Check whether a GPU is available and pick the corresponding device (CPU or GPU).

Define a named tuple called Transition that represents a single transition (state, action, next state, reward).
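
As a quick illustration (the tensor values here are made up), a Transition is just a four-field record:

t = Transition(state=torch.zeros(1, 4),
               action=torch.tensor([[0]]),
               next_state=torch.zeros(1, 4),
               reward=torch.tensor([1.0]))
print(t.state.shape, t.action.item(), t.reward.item())   # torch.Size([1, 4]) 0 1.0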

Define Functions

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

ReplayMemory stores transitions and samples random mini-batches of them for training.
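
A small usage sketch with dummy transitions (all values here are made up) shows how the buffer will be used later:

mem = ReplayMemory(1000)
for i in range(5):
    mem.push(torch.randn(1, 4),         # state
             torch.tensor([[i % 2]]),   # action
             torch.randn(1, 4),         # next_state
             torch.tensor([1.0]))       # reward
print(len(mem))        # 5
batch = mem.sample(3)  # a list of 3 randomly chosen Transition tuples
print(batch[0].action)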

class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

Defines DQN, the deep neural network used to approximate the Q-values.

  • It uses three fully connected layers (layer1, layer2, and layer3).
  • forward computes the forward pass of the network; a quick shape check follows.
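
As a shape check on an untrained network (CartPole has 4 observations and 2 actions; the random inputs are just placeholders):

net = DQN(n_observations=4, n_actions=2)
single = torch.randn(1, 4)      # one state, with a batch dimension
batch = torch.randn(32, 4)      # a batch of 32 states
print(net(single).shape)        # torch.Size([1, 2])  -> one Q-value per action
print(net(batch).shape)         # torch.Size([32, 2])
print(net(single).max(1)[1])    # index of the greedy action
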
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0

Define the hyperparameters required for training:

  • batch size (BATCH_SIZE)
  • discount factor (GAMMA), plus the starting value (EPS_START) and final value (EPS_END) of ε for the ε-greedy policy
  • decay rate of ε (EPS_DECAY); the resulting schedule is sketched below
  • target network update rate (TAU)
  • learning rate (LR)

This block also builds the policy network and the target network, copies the policy weights into the target network, sets up the AdamW optimizer, and allocates a replay memory holding up to 10,000 transitions.
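
With these values, the exploration rate decays roughly as follows (a standalone sketch of the same formula used in select_action below):

# eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)
for steps in (0, 500, 1000, 2000, 5000):
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps / EPS_DECAY)
    print(steps, round(eps, 3))
# 0 0.9, 500 0.566, 1000 0.363, 2000 0.165, 5000 0.056
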
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

  • state: the current state, a PyTorch tensor.
  • sample = random.random(): generates a random number between 0 and 1, used to decide whether to take a random action.
  • eps_threshold: the ε (exploration rate) threshold computed for the ε-greedy policy. ε starts at EPS_START and decays exponentially toward EPS_END as training proceeds.
  • steps_done += 1: incremented every time an action is selected, and used to drive the decay of ε.

Algorithm

  • If sample > eps_threshold: according to the ε-greedy policy, exploit the best known action. Gradient computation is disabled with torch.no_grad(), the policy network policy_net predicts the Q-value of each action in the current state, and the action with the highest Q-value is selected (max(1)[1]); the indexing is illustrated below.

  • If sample <= eps_threshold: according to the ε-greedy policy, explore by choosing a random action. In this case env.action_space.sample() draws a random action from the environment's action space.
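
The tensor indexing in the greedy branch can be confusing at first; here it is on made-up Q-values for a single state:

q_values = torch.tensor([[0.12, 0.87]])   # Q(s, left), Q(s, right)
print(q_values.max(1)[0])                 # tensor([0.8700]) -> best Q-value
print(q_values.max(1)[1])                 # tensor([1])      -> index of the best action
action = q_values.max(1)[1].view(1, 1)    # reshape to 1x1, matching env.step(action.item())
print(action.item())                      # 1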

episode_durations = []

def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

plot_durations visualizes the agent's performance by plotting the duration of each training episode, together with a 100-episode moving average once enough episodes are available; the unfold trick behind the average is sketched below.
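
The moving average is built with torch.Tensor.unfold, which creates sliding windows; here is the same idea with a window of 4 instead of 100 (toy data):

t = torch.arange(10, dtype=torch.float)
windows = t.unfold(0, 4, 1)   # shape (7, 4): windows [0..3], [1..4], ..., [6..9]
print(windows.mean(1))        # tensor([1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, 7.5000])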

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

optimize_model() performs one optimization step for the Deep Q-Network (DQN).

It samples a mini-batch from replay memory, approximates the Q-function with the deep neural network, computes the loss against the TD target, and updates the policy network's weights by gradient descent to maximize the expected return. A toy version of the target computation follows.
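
A toy calculation of the target used above, with made-up rewards and target-network values (the terminal entry keeps a next-state value of 0):

reward_batch = torch.tensor([1.0, 1.0, 1.0])
next_state_values = torch.tensor([0.95, 0.0, 1.10])   # max_a' Q_target(s', a'); 0 for the terminal state
expected = next_state_values * GAMMA + reward_batch   # r + gamma * max_a' Q_target(s', a')
print(expected)                                       # tensor([1.9405, 1.0000, 2.0890])

# Huber (SmoothL1) loss between the policy network's Q(s, a) and this target:
state_action_values = torch.tensor([[1.80], [1.20], [2.00]])
loss = nn.SmoothL1Loss()(state_action_values, expected.unsqueeze(1))
print(loss)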

Train DQN

if torch.cuda.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

Train the Deep Q-Network (DQN).

By interacting with the environment, storing transitions in replay memory, optimizing the policy network, and softly updating the target network (θ′ ← τθ + (1 − τ)θ′), this implements the full deep Q-learning training procedure.

Through repeated iteration the model learns a policy that maximizes the overall return. A minimal greedy rollout of the trained policy is sketched below.
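
After training, one way to watch the learned policy is a greedy rollout (not part of the original tutorial, just a sketch; render_mode="human" opens a window and may not work in a headless setup):

eval_env = gym.make("CartPole-v1", render_mode="human")
state, info = eval_env.reset()
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
total_reward, done = 0.0, False
while not done:
    with torch.no_grad():
        action = policy_net(state).max(1)[1].view(1, 1)   # always exploit, no exploration
    observation, reward, terminated, truncated, _ = eval_env.step(action.item())
    total_reward += reward
    done = terminated or truncated
    state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
print(total_reward)
eval_env.close()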
