如 Tensorflow 一般,PyTorch 也擁有自己的 distribute strategy 來做模型平行化和資料平行化的運算。今天我們就來看一下 PyTorch 如何進行平行化:
有的時候為了達到較高的精準度,模型會有複雜度增加的傾向,而往往過於複雜的模型,通常規模也相當龐大而以致於不易放入單一個 GPU,這個時候就可以利用模型平行化。與等一下會談到的資料平行化相似,將模型切為可獨立執行的幾個小模型。我們可以用一個具有十層的模型來簡單說明模型平行化如何發生在兩個 GPUs 上。在非平行化的模型中,每一個 replica 會有完整模型的拷貝版,而在平行化的模型中,每一個 replica 會持有各五層的部分模型。我們用一個玩具模型來說明該如何設計原始碼達成模型平行化。
import torch
import torch.nn as nn
import torch.optim as optim
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = torch.nn.Linear(10, 10).to('cuda:0')
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(10, 5).to('cuda:1')
def forward(self, x):
x = self.relu(self.net1(x.to('cuda:0'))) # 指派到第一個 GPU
return self.net2(x.to('cuda:1')) # 指派到第二個 GPU
看起來真是輕鬆容易,因為隨後建立的 optimizer 會協調程序計算梯度,使用者只要注意傳入的標注要和訓練特徵一致,預防計算出錯的損失值即可。
接下來我們則用一個實例來說明該如何利用最少的原始碼更改,達成模型平行化。這個模型就是 torchvision 的 ResNet50。
我們將會把 ResNet50 對兩個 GPU 做模型平行化。在這過程中,我們必須繼承原來的 ResNet50 類別,將模型分成平均分配在兩個 GPU 上,複寫 forward
方法使在兩個 GPU 上計算完成的姐果可以整合成為最終值。在下面的原始碼,我們將上述的想法實作。
from torchvision.models.resnet import ResNet, Bottleneck #import ResNet 類別
num_classes = 1000
class ModelParallelResNet50(ResNet): # 複寫 ResNet
def __init__(self, *args, **kwargs):
super(ModelParallelResNet50, self).__init__(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)
self.seq1 = nn.Sequential(
self.conv1,
self.bn1,
self.relu,
self.maxpool,
self.layer1,
self.layer2
).to('cuda:0') # 部分模型一分配到第一個 GPU
self.seq2 = nn.Sequential(
self.layer3,
self.layer4,
self.avgpool,
).to('cuda:1') # 部分模型一分配到第二個 GPU
self.fc.to('cuda:1') #最後的全連接在第一個 GPU
def forward(self, x):
x = self.seq2(self.seq1(x).to('cuda:1')) #
return self.fc(x.view(x.size(0), -1))
根據上面的原始碼,我們的確成功地將模型指派到相對應的 GPUs,然而這個模型有一個問題,那就是將 seq1 的中繼資料從 'cuda:0'
搬到 'cuda:1'
,而造成延遲。從 PyTorch 的官方資料我們可以看到 ModelParallelResNet50 還要比 single GPU 還慢一點。
模型平行化可以和資料平行化一起運行取得最佳的效能。這裡介紹該如何將批次資料轉成更小的批次資料分散到每個 GPU 上執行。下面原始碼會建立一個 PipelineParallelResNet50
類別,該類別為 ModelParallelResNet50
的子類別。
class PipelineParallelResNet50(ModelParallelResNet50):
def __init__(self, split_size=20, *args, **kwargs):
super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
self.split_size = split_size
def forward(self, x):
splits = iter(x.split(self.split_size, dim=0)) # 將輸入再拆成 len(x) / split_size
s_next = next(splits)
s_prev = self.seq1(s_next).to('cuda:1')
ret = []
for s_next in splits:
# A. s_prev runs on cuda:1
s_prev = self.seq2(s_prev)
ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))
# B. s_next runs on cuda:0, which can run concurrently with A
s_prev = self.seq1(s_next).to('cuda:1')
s_prev = self.seq2(s_prev)
ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))
return torch.cat(ret)
接著在 forward
的方法內,實作資料分割的部分,大家可以看到,增加了資料管線的部分贏得許多的執行時間縮減,因此是最有效率的。
Data Parallel 主要是將訓練資料根據可得的運算資源,分割成一或更多的訓練資料子集,分配到運算單元中執行。在這個平行化策略,模型在每一個運算單元都會存有一份拷貝,在每一個運算單元完成。而 Distributed Data Parallel 則需要實作 communications 的部分。兩者的差異可以列舉如下:
知道了這兩者的分別,我們先來看如何建立一個 Data Parallel 的模型。Data Parallel 的實作項目包裝在 tf.nn
內,我們可以用 nn.DataParallel
重新包覆建立的無資料名行化模型為資料平行化模型,呼叫有如這段程式碼:model = nn.DataParallel(model)
在這裡要注意的是,使用者必須要手動去檢查是否有多個 GPU,藉著呼叫函示torch.cuda.device_count()
來得到 GPU 的個數。
PyTorch 的 nn.DataParallel
將許多關於多線程的細節都實作且封裝在模組內,使用者無須擔心實作細節,包括如何 scatter 資料和 gather 計算結果,只需呼叫 nn.DataParallel(model)
即可,非常容易上手。
Distributed Data Parallel 又被稱為 DDP,其實作已經被封裝到 torch.distributed
模組內。為了能讓 DDP 運作,使用者需要先建立且初始化 process groups(setup
函式),並提供當 process groups 停止後的清理原始碼()。至於如何將已存在的 nn.Module
轉為可執行Distributed Data Parallel 的模型,做法和Data Parallel 一樣簡單,只要用 DistributedDataParallel
包覆模型物件即可。我們現在來看一下原始碼:
import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
def setup(rank, world_size):
# 準備 process groups 運行的環境
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# 初始化 process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# 明確的指定 seed 以保證兩個程序都會從同樣的任意權重和偏移參數開始
torch.manual_seed(42)
def cleanup():
# 清理 process groups
dist.destroy_process_group()
在 setup()
中,我們可以設定一些環境變數,但最重要多是藉著呼叫 torch.distributed
模組函式--dist.init_process_group
初始化 process group。另外在結束時,也需要呼叫 cleanup()
。在 cleanup() 內容中,最重要的應該是呼叫
torch.distributed模組函式 --
dist.destroy_process_group()`結束 process groups。而在 8 個 GPUs 的訓練原始碼則如下:
from torch.nn.parallel import DistributedDataParallel as DDP
def demo_basic(rank, world_size):
setup(rank, world_size)
# 為這個 process 準備環境, rank 1 使用 GPUs [0, 1, 2, 3] 而
# rank 2 使用 GPUs [4, 5, 6, 7].
n = torch.cuda.device_count() // world_size
device_ids = list(range(rank * n, (rank + 1) * n))
# 建立模型並將之移到 device_ids[0](GPU:0)
model = ToyModel().to(device_ids[0])
# 處理輸出的設備預設為 device_ids[0](GPU:0)
ddp_model = DDP(model, device_ids=device_ids)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(device_ids[0])
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
和Data Parallel 一樣,大部分底層的邏輯都被封裝起來,使用者無需關心底層的實作情況便能輕鬆使用。但,事情並不沒有如此美好,還有一些是需要使用者注意的,如,使用者需調整工作負載狀況。為了防止最有效率的程序與最沒有效率的程序不會相差過大,而耗費較長的時間在等待最差效率者的完成工作,使用者需要提供一個足夠大的 timeout 時間給init_process_group
。這個問題又被稱為 skewed processing speeds,常見於網路的延遲,資源競爭和不可預期的工作負載高峰等。
此外還有儲存和載入模型的問題,在使用 Distributed Data Parallel 時,只允許一個程序做寫入模型的動作,一但被寫入後,所有的程序都可以讀入。下面是 PyTorch 官方網頁 demo_checkpoint 韓式的一部分。因為 demo_checkpoint
和 demo_basic
有點相似,所以在下面僅列出不相同的地方,完整的程式碼可以到官方網頁去觀看。
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
用其中一個 process 儲存
torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
另外, optimizer 要保證所有的參數都會在同一迭代等待同步化處理梯度更新,因此若有一個程序還在寫入的狀態時,其他的程序不可讀入模型而造成參數值不盡相同。為了能完成這兩點,使用者必須提供一個設計良好的 map_location
給 load_state_dict
。該物件會防止程序踰矩侵佔其他程序所使用的設備,也防止所有的程序都使用同一個設備。
下面就是如何使用map_location
# 使用 barrier() 去阻擋其他 process 載入模型,當 process 0 還在儲存時
dist.barrier()
# 適當調整 map_location
rank0_devices = [x - rank * len(device_ids) for x in device_ids]
device_pairs = zip(rank0_devices, device_ids)
map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location))
上面的原始碼就是 main 的入口函式,在這裡會開始 process groups 而使模型可以在多程序平行計算中執行。
# 使用 torch 提供的 multiprocessing,可以免除一些 serialization 的問題
import torch.multiprocessing as mp
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
if __name__ == "__main__":
run_demo(demo_basic, 2)
run_demo(demo_checkpoint, 2)