Data parallelism is a common way to speed up the training of deep learning models. PyTorch provides the DistributedDataParallel (DDP) module for efficient multi-GPU distributed training. DDP uses the communication collectives in the torch.distributed package to synchronize gradients, parameters, and buffers, and it supports multi-process training on a single machine or across multiple machines. This article walks through the basic concepts, implementation, and optimization techniques of distributed data parallelism in PyTorch, helping you use multi-GPU resources to accelerate model training efficiently.
Compared with DataParallel, DDP has significant advantages. When a model is too large to fit on a single GPU, DDP can be combined with model parallelism, which DataParallel does not support. DataParallel is single-process, multi-threaded and works only on a single machine, while DDP is multi-process and supports both single-machine multi-GPU and multi-machine multi-GPU training, so it trains faster. Before using DDP, the distributed environment must be set up correctly. The following code shows how to initialize and clean up the process group:
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
    # Address and port of the rank-0 process; all processes rendezvous here.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize the process group with the gloo backend (nccl is usually preferred for GPU training).
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Fix the seed so every process builds an identically initialized model.
    torch.manual_seed(42)

def cleanup():
    dist.destroy_process_group()
Wrap the defined model in DDP to enable distributed training:
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank, world_size):
    setup(rank, world_size)
    # Partition the visible GPUs evenly among the processes.
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    # Move the model to this process's first device and wrap it in DDP.
    model = ToyModel().to(device_ids[0])
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # One training step: forward, backward (gradients are synchronized here), update.
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

def run_demo(demo_fn, world_size):
    # Spawn one process per rank; each process runs demo_fn(rank, world_size).
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)
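As a minimal, illustrative sketch (not part of the original text), run_demo can launch demo_basic like this; it assumes a machine with at least 2 GPUs and uses a world_size of 2 as an example:

if __name__ == "__main__":
    # Example launch: 2 processes, each owning an equal share of the visible GPUs.
    run_demo(demo_basic, world_size=2)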
In distributed training, saving and loading checkpoints improves reliability and makes it possible to resume training after a failure. DDP provides an efficient pattern for saving and loading the model:
import tempfile

def demo_checkpoint(rank, world_size):
    setup(rank, world_size)
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    model = ToyModel().to(device_ids[0])
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # Only one process writes the checkpoint; all replicas hold identical parameters.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Make sure rank 0 has finished saving before any process starts loading.
    dist.barrier()

    # Map tensors saved from rank 0's devices onto this process's devices.
    rank0_devices = [x - rank * len(device_ids) for x in device_ids]
    device_pairs = zip(rank0_devices, device_ids)
    map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Keep the file alive until every process has finished loading, then remove it.
    dist.barrier()
    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
DDP can also be combined with a model that itself spans multiple GPUs, to support larger models and more data:
class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        # Split the model across two devices.
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        # Move the activations to the second device for the second stage.
        x = x.to(self.dev1)
        return self.net2(x)

def demo_model_parallel(rank, world_size):
    setup(rank, world_size)

    # Each process owns two consecutive GPUs.
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    # For a multi-device model, do not pass device_ids; DDP uses the model's own placement.
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # The output of forward() lives on dev1, so the labels must be moved there too.
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()
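Like demo_basic, the checkpoint and model-parallel demos are launched through run_demo. The sketch below is illustrative rather than part of the original text; it assumes at least 2 GPUs for the checkpoint demo and at least 4 GPUs for the model-parallel demo, where each process occupies 2 GPUs:

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    # Checkpoint demo: two data-parallel processes (assumes >= 2 GPUs).
    run_demo(demo_checkpoint, world_size=2)
    # Model-parallel demo: each process spans 2 GPUs, so world_size = n_gpus // 2.
    if n_gpus >= 4:
        run_demo(demo_model_parallel, world_size=n_gpus // 2)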
This article covered the basic concepts, implementation, and optimization techniques of distributed data parallelism in PyTorch: the comparison between DDP and DataParallel, basic DDP usage, saving and loading checkpoints, and combining DDP with model parallelism. With these building blocks you can use multi-GPU resources to accelerate the training of deep learning models efficiently. From here you can explore more complex distributed training scenarios and optimization strategies to handle larger models and datasets. 编程狮 (w3cschool) will continue to provide more high-quality tutorials on distributed training for deep learning.