PyTorch 分布式數(shù)據(jù)并行入門

2025-06-19 10:16 更新

在深度學(xué)習(xí)模型的訓(xùn)練過程中,數(shù)據(jù)并行是一種常見的加速訓(xùn)練的方法。PyTorch 提供了 DistributedDataParallel(DDP)模塊,用于實現(xiàn)高效的多 GPU 分布式訓(xùn)練。DDP 通過使用 torch.distributed 包中的通信集合來同步梯度、參數(shù)和緩沖區(qū),支持在單機或跨多機的多進程環(huán)境中進行訓(xùn)練。本文將帶您深入了解 PyTorch 分布式數(shù)據(jù)并行的基本概念、實現(xiàn)方法和優(yōu)化技巧,幫助您高效地利用多 GPU 資源加速模型訓(xùn)練。

一、DDP 與 DataParallel 的比較

DDP 與 DataParallel 相比具有顯著的優(yōu)勢:

  1. 模型并行支持 :DDP 可與模型并行結(jié)合使用,而 DataParallel 不支持。當模型規(guī)模過大無法放入單個 GPU 時,DDP 能更好地處理這種情況。
  2. 多進程支持DataParallel 是單進程多線程,僅限單機;DDP 是多進程,支持單機多 GPU 和多機多 GPU 訓(xùn)練,訓(xùn)練速度更快。
  3. 性能優(yōu)化 :DDP 預(yù)先復(fù)制模型,避免了每次迭代的模型復(fù)制和全局解釋器鎖定,性能更優(yōu)。

二、DDP 基本用法

(一)環(huán)境設(shè)置

在使用 DDP 之前,需要正確設(shè)置分布式環(huán)境。以下代碼展示了如何初始化和清理進程組:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    torch.manual_seed(42)


def cleanup():
    dist.destroy_process_group()

(二)定義模型并封裝 DDP

將定義好的模型封裝到 DDP 中,以實現(xiàn)分布式訓(xùn)練:

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)


    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))
    model = ToyModel().to(device_ids[0])
    ddp_model = DDP(model, device_ids=device_ids)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()
    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)

三、保存和加載檢查點

在分布式訓(xùn)練中,檢查點的保存和加載可以提高訓(xùn)練的可靠性和恢復(fù)能力。DDP 提供了優(yōu)化的方法來保存和加載模型:

import tempfile


def demo_checkpoint(rank, world_size):
    setup(rank, world_size)
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))
    model = ToyModel().to(device_ids[0])
    ddp_model = DDP(model, device_ids=device_ids)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
    dist.barrier()
    rank0_devices = [x - rank * len(device_ids) for x in device_ids]
    device_pairs = zip(rank0_devices, device_ids)
    map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
    ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()
    dist.barrier()
    if rank == 0:
        os.remove(CHECKPOINT_PATH)
    cleanup()

四、DDP 與模型并行結(jié)合

DDP 還可以與多 GPU 模型結(jié)合使用,以支持更大的模型和更多的數(shù)據(jù):

class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)


    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)


def demo_model_parallel(rank, world_size):
    setup(rank, world_size)
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
    optimizer.zero_grad()
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    cleanup()

五、總結(jié)與展望

本文詳細介紹了 PyTorch 中分布式數(shù)據(jù)并行的基本概念、實現(xiàn)方法和優(yōu)化技巧,包括 DDP 與 DataParallel 的比較、DDP 的基本用法、保存和加載檢查點的方法,以及 DDP 與模型并行的結(jié)合使用。通過這些內(nèi)容,您可以利用多 GPU 資源高效地加速深度學(xué)習(xí)模型的訓(xùn)練過程。未來,您可以進一步探索更復(fù)雜的分布式訓(xùn)練場景和優(yōu)化策略,以滿足更大規(guī)模模型和數(shù)據(jù)集的訓(xùn)練需求。編程獅將持續(xù)為您提供更多深度學(xué)習(xí)分布式訓(xùn)練的優(yōu)質(zhì)教程,助力您的技術(shù)成長之路。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號