PyTorch torch隨機

2025-07-02 11:30 更新

PyTorch 隨機數(shù)生成詳解:從基礎(chǔ)到進階

一、PyTorch 隨機數(shù)生成簡介

在深度學習中,隨機數(shù)生成是一個重要的環(huán)節(jié),尤其是在模型訓練和數(shù)據(jù)預處理過程中。PyTorch 提供了豐富的隨機數(shù)生成函數(shù),可以幫助我們控制隨機數(shù)生成過程,確保實驗的可重復性。本教程將深入淺出地講解 PyTorch 中的隨機數(shù)生成函數(shù),并提供實際的代碼示例,幫助您輕松掌握這些函數(shù)的使用。

二、PyTorch 隨機數(shù)生成函數(shù)詳解

(一)torch.random.manual_seed(seed)

這個函數(shù)用于設(shè)置隨機數(shù)生成的種子。通過設(shè)定一個固定的種子值,可以確保每次運行代碼時生成的隨機數(shù)序列相同,從而保證實驗結(jié)果的可重復性。

參數(shù)說明

  • seedpython:int):所需的種子值。

示例代碼

import torch


## 設(shè)置隨機種子
torch.random.manual_seed(42)


## 生成隨機數(shù)
random_tensor = torch.rand(3, 3)
print(random_tensor)

(二)torch.random.seed()

torch.random.seed() 函數(shù)將隨機數(shù)生成的種子設(shè)置為一個不確定的隨機數(shù)。這通常用于希望每次運行代碼時都生成不同的隨機數(shù)序列的場景。

返回值

返回用于播種 RNG 的 64 位數(shù)字。

(三)torch.random.initial_seed()

這個函數(shù)返回用于生成隨機數(shù)的初始種子。這個種子是在 PyTorch 初始化時設(shè)置的默認種子。

返回值

返回一個 Python long 類型的值,表示初始種子。

(四)torch.random.get_rng_state()

torch.random.get_rng_state() 函數(shù)以 torch.ByteTensor 的形式返回隨機數(shù)生成器(RNG)的狀態(tài)。這個狀態(tài)包含了生成隨機數(shù)序列所需的所有信息。

返回值

返回一個 torch.ByteTensor,表示 RNG 的當前狀態(tài)。

(五)torch.random.set_rng_state(new_state)

這個函數(shù)用于設(shè)置隨機數(shù)生成器的狀態(tài)。通過設(shè)置 RNG 的狀態(tài),可以恢復到之前保存的狀態(tài),從而繼續(xù)生成特定的隨機數(shù)序列。

參數(shù)說明

  • new_statetorch.ByteTensor):所需的 RNG 狀態(tài)。

(六)torch.random.fork_rng(devices=None, enabled=True)

torch.random.fork_rng() 函數(shù)分叉 RNG,以便在您返回時將 RNG 重置為之前的狀態(tài)。這在需要臨時改變隨機數(shù)生成狀態(tài)的場景中非常有用,例如在并行計算中。

參數(shù)說明

  • devices(可迭代的 CUDA ID 的列表):指定要分叉 RNG 的 CUDA 設(shè)備。CPU RNG 狀態(tài)始終會被分叉。如果您的計算機上有多個設(shè)備,明確指定設(shè)備可以避免警告。
  • enabledbool):如果設(shè)置為 False,則不分叉 RNG。這在需要禁用上下文管理器時非常方便,無需刪除代碼并調(diào)整縮進。

示例代碼

with torch.random.fork_rng(devices=[0, 1], enabled=True):
    # 在這個上下文中,RNG 被分叉
    random_tensor_in_fork = torch.rand(3, 3)
    print(random_tensor_in_fork)


## 退出上下文后,RNG 恢復到之前的狀態(tài)
random_tensor_after_fork = torch.rand(3, 3)
print(random_tensor_after_fork)

三、綜合示例:控制隨機數(shù)生成在深度學習中的應用

(一)確保實驗可重復性

在深度學習實驗中,我們通常希望通過設(shè)置隨機種子來確保實驗的可重復性。以下是一個完整的示例,展示了如何在訓練模型時控制隨機數(shù)生成:

import torch
import torch.nn as nn
import torch.optim as optim


## 設(shè)置隨機種子以確??芍貜托?torch.random.manual_seed(42)


## 定義一個簡單的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 2)


    def forward(self, x):
        return self.linear(x)


## 創(chuàng)建模型、優(yōu)化器和損失函數(shù)
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()


## 生成隨機數(shù)據(jù)
inputs = torch.randn(100, 10)
targets = torch.randn(100, 2)


## 訓練模型
for epoch in range(10):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()


    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

(二)在數(shù)據(jù)加載中使用隨機數(shù)生成

在加載數(shù)據(jù)集時,我們通常會使用隨機數(shù)生成來隨機打亂數(shù)據(jù)順序。以下是一個示例:

from torch.utils.data import DataLoader, TensorDataset


## 設(shè)置隨機種子
torch.random.manual_seed(42)


## 創(chuàng)建數(shù)據(jù)集
inputs = torch.randn(100, 10)
targets = torch.randn(100, 2)
dataset = TensorDataset(inputs, targets)


## 創(chuàng)建數(shù)據(jù)加載器并隨機打亂數(shù)據(jù)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)


## 遍歷數(shù)據(jù)加載器
for batch in dataloader:
    batch_inputs, batch_targets = batch
    # 在這里進行訓練步驟

四、總結(jié)與最佳實踐

通過本教程,我們詳細介紹了 PyTorch 中的隨機數(shù)生成函數(shù)及其在深度學習中的應用。正確控制隨機數(shù)生成對于確保實驗的可重復性至關(guān)重要。以下是一些最佳實踐建議:

  1. 在實驗開始時設(shè)置隨機種子(使用 torch.random.manual_seed()),以確保結(jié)果的可重復性。
  2. 如果需要臨時改變隨機數(shù)生成狀態(tài),可以使用 torch.random.fork_rng() 來分叉 RNG,并在完成后恢復到之前的狀態(tài)。
  3. 在分布式訓練或多 GPU 訓練中,注意為每個設(shè)備正確設(shè)置和管理 RNG 狀態(tài),以避免隨機數(shù)序列的混亂。
  4. 理解不同隨機數(shù)生成函數(shù)的用途,根據(jù)實際需求選擇合適的函數(shù)來控制隨機性。
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號