Getting Started with the PyTorch Distributed RPC Framework

Updated 2025-06-23 09:56

In distributed deep learning, PyTorch's torch.distributed.rpc package provides a flexible and powerful mechanism for building complex distributed applications. With remote procedure calls (RPC) and remote references (RRef), developers can pass data and method invocations between processes with ease, enabling efficient distributed training and inference. Through detailed code examples and explanations of the underlying ideas, this article will help you quickly grasp the core concepts and usage of the PyTorch distributed RPC framework.

1. RPC and RRef: The Core of Distributed Communication

(1) RPC Basics

RPC (remote procedure call) lets one process (the client) call a function in another process (the server) as if it were a local function. In PyTorch's torch.distributed.rpc package, RPC offers two calling styles, rpc_sync and rpc_async, for blocking and non-blocking communication respectively.

import torch
import torch.distributed.rpc as rpc

# Blocking RPC call: returns only after the remote call has completed
result = rpc.rpc_sync("dest_worker", torch.add, args=(torch.tensor(2), 3))

# Non-blocking RPC call: returns a Future immediately
future = rpc.rpc_async("dest_worker", torch.add, args=(torch.tensor(2), 3))
result = future.wait()
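Both calls above assume the RPC framework has already been initialized on every participating process. A minimal sketch of that setup, assuming two processes and illustrative worker names and port:

import os
import torch.distributed.rpc as rpc

# Illustrative setup: every process must call init_rpc with a unique name
# and the same world_size before issuing any RPCs.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
# On this process (rank 0); the peer would call
# rpc.init_rpc("dest_worker", rank=1, world_size=2)
rpc.init_rpc("my_worker", rank=0, world_size=2)

# ... issue rpc_sync / rpc_async calls here ...

rpc.shutdown()  # blocks until all outstanding RPCs on every worker finish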

(2) RRef: Remote Object References

An RRef (Remote Reference) refers to an object that lives on a remote worker. It lets developers share and operate on data across processes without worrying about where the object physically resides.

from torch.distributed.rpc import RRef

# Create a remote object; rpc.remote returns an RRef pointing to it
rref = rpc.remote("dest_worker", torch.randn, args=(3, 3))

# Fetch the value of the remote object to the local process
value = rref.to_here()
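An RRef can also be passed as an argument to further RPC calls, so the data never leaves its owner until it is actually needed. A small illustrative sketch under that assumption, using a hypothetical helper sum_remote (the helper must be defined in a module visible to both caller and callee):

import torch
import torch.distributed.rpc as rpc

def sum_remote(rref_a, rref_b):
    # Runs on the callee, which owns both tensors, so to_here() is a local fetch
    return rref_a.to_here() + rref_b.to_here()

rref_a = rpc.remote("dest_worker", torch.randn, args=(3, 3))
rref_b = rpc.remote("dest_worker", torch.randn, args=(3, 3))
# Pass the RRefs themselves; only the final sum is shipped back to the caller
total = rpc.rpc_sync("dest_worker", sum_remote, args=(rref_a, rref_b))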

2. Distributed Reinforcement Learning Example

(1) Defining the Policy Network

The policy network is the core component of reinforcement learning: it selects an action based on the current state.

import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        # CartPole observations have 4 features; there are 2 discrete actions
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        # Return a probability distribution over the two actions
        return F.softmax(action_scores, dim=1)
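Before wiring the network into RPC, it can be sanity-checked locally. A quick illustrative example (the 4-dimensional input mimics a CartPole observation):

import torch

# Illustrative local check, independent of any distributed setup
policy = Policy()
dummy_state = torch.randn(1, 4)       # one fake CartPole observation
probs = policy(dummy_state)           # shape (1, 2): probabilities of the two actions
action = torch.distributions.Categorical(probs).sample().item()
print(probs, action)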

(2) Implementing the Observer and the Agent

The observer interacts with the environment, while the agent updates the policy network based on the data the observers collect.

import gym
import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef


def _call_method(method, rref, *args, **kwargs):
    # Run an instance method on the object the RRef points to
    # (this executes on the RRef's owner)
    return method(rref.local_value(), *args, **kwargs)


class Observer:
    def __init__(self, rank):
        self.rank = rank
        # Classic gym (< 0.26) API: reset() returns the state, step() a 4-tuple
        self.env = gym.make('CartPole-v1')

    def run_episode(self, agent_rref, n_steps):
        state = self.env.reset()
        for _ in range(n_steps):
            # Ask the agent (on its owner process) to pick an action for this observer
            action = rpc.rpc_sync(
                agent_rref.owner(), _call_method,
                args=(Agent.select_action, agent_rref, self.rank, state))
            next_state, reward, done, _ = self.env.step(action)
            # Report the reward back to the agent
            rpc.rpc_sync(
                agent_rref.owner(), _call_method,
                args=(Agent.report_reward, agent_rref, self.rank, reward))
            state = next_state
            if done:
                break


class Agent:
    def __init__(self, world_size):
        self.agent_rref = RRef(self)
        self.ob_rrefs = []
        self.rewards = {}  # rewards collected per observer
        self.policy = Policy()
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=1e-2)
        # Create one remote Observer on every observer worker (ranks 1 .. world_size-1)
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(f"observer_{ob_rank}")
            self.ob_rrefs.append(rpc.remote(ob_info, Observer, args=(ob_rank,)))

    def select_action(self, ob_id, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = torch.distributions.Categorical(probs)
        action = m.sample()
        return action.item()

    def report_reward(self, ob_id, reward):
        # Save the reward so it can be used when updating the policy
        pass

    def run_episode(self, n_steps):
        # Trigger each observer to run an episode (see the sketch below)
        pass

    def finish_episode(self):
        # Update the policy network from the collected rewards
        pass
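The run_episode and finish_episode bodies are left as stubs here. As an illustration of what the fan-out in run_episode could look like, here is a sketch only, assuming the agent keeps self.agent_rref = RRef(self) and the ob_rrefs created in __init__; it is not the full tutorial implementation:

def run_episode_sketch(agent, n_steps):
    # Sketch only: ask every remote observer to run one episode in parallel
    futs = []
    for ob_rref in agent.ob_rrefs:
        futs.append(rpc.rpc_async(
            ob_rref.owner(), _call_method,
            args=(Observer.run_episode, ob_rref, agent.agent_rref, n_steps)))
    # Wait until all observers have finished their episodes
    for fut in futs:
        fut.wait()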

(3) Launching Distributed Training

import os
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # Rank 0 is the agent: it drives training
        rpc.init_rpc("agent", rank=rank, world_size=world_size)
        agent = Agent(world_size)
        for episode in range(100):
            agent.run_episode(n_steps=10)
            agent.finish_episode()
    else:
        # The remaining ranks are observers; they just wait for RPCs from the agent
        rpc.init_rpc(f"observer_{rank}", rank=rank, world_size=world_size)
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2, join=True)
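If you need to tune the communication layer, init_rpc also accepts backend options. A hedged sketch, assuming the default TensorPipe backend (this call would replace the plain init_rpc above; the values are illustrative):

# Assumption: default TensorPipe backend; adjust values to your workload
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=120)
rpc.init_rpc("agent", rank=0, world_size=2, rpc_backend_options=options)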

3. Distributed Model-Parallel Training Example

(1) Defining the Distributed RNN Model

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef


def _call_method(method, rref, *args, **kwargs):
    # Run a method on the object the RRef points to (executes on the RRef's owner)
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    # Invoke _call_method on the RRef's owner through a synchronous RPC
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)


def _parameter_rrefs(module):
    # Wrap every parameter of a local module in an RRef so the
    # DistributedOptimizer can address it
    return [RRef(p) for p in module.parameters()]


class EmbeddingTable(nn.Module):
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        # The embedding table lives on the parameter server's GPU
        self.encoder = nn.Embedding(ntoken, ninp).cuda()

    def forward(self, input):
        return self.drop(self.encoder(input.cuda())).cpu()


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, output):
        return self.decoder(self.drop(output))


class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        # Embedding table and decoder live remotely on the parameter server;
        # the LSTM stays local to the trainer
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden

    def parameter_rrefs(self):
        # Collect RRefs to all parameters, local and remote, for the DistributedOptimizer
        remote_params = []
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        remote_params.extend(_parameter_rrefs(self.rnn))
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params

(2) Implementing the Distributed Training Loop

import os
import torch
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.optim import DistributedOptimizer


def run_trainer():
    # Sizes are kept tiny so the example runs quickly
    model = RNNModel('ps', ntoken=10, ninp=2, nhid=3, nlayers=4)
    # The DistributedOptimizer takes RRefs to parameters on every node
    opt = DistributedOptimizer(optim.SGD, model.parameter_rrefs(), lr=0.05)
    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch():
        # Toy random data: 5 batches of token indices and targets
        for _ in range(5):
            data = torch.LongTensor(5, 3) % 10
            target = torch.LongTensor(5, 10) % 3
            yield data, target

    for epoch in range(10):
        for data, target in get_next_batch():
            # Each forward/backward/step runs inside a distributed autograd context
            with dist_autograd.context() as context_id:
                output, hidden = model(data, (torch.randn(4, 3, 3), torch.randn(4, 3, 3)))
                loss = criterion(output, target)
                # Distributed backward pass: gradients accumulate in the context
                dist_autograd.backward(context_id, [loss])
                # Distributed optimizer step; no zero_grad needed, because the
                # context is discarded at the end of each iteration
                opt.step(context_id)


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        # Rank 1 is the trainer; it owns the LSTM and runs the training loop
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        run_trainer()
    else:
        # Rank 0 is the parameter server; it hosts the embedding table and decoder
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2, join=True)
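Because gradients are accumulated in the distributed autograd context rather than in each parameter's .grad field, they can be inspected there if needed. A small illustrative sketch; it would go inside the with dist_autograd.context() block, after the backward call:

# Illustrative only: look at the gradients recorded on this node for the
# current context; they are not written to param.grad.
grads = dist_autograd.get_gradients(context_id)  # dict mapping Tensor -> gradient Tensor
for param, grad in grads.items():
    print(param.shape, grad.norm())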

4. Summary and Outlook

This article has walked through the core concepts and usage of the PyTorch distributed RPC framework. RPC and RRef are powerful tools for building complex models and training pipelines in distributed environments. From here, you can explore applying these techniques in real projects to tackle larger-scale training and inference workloads. 編程獅 will continue to publish in-depth tutorials on distributed deep learning to support your technical growth.
