PyTorch 入門

2025-06-18 14:04 更新

PyTorch 簡介

PyTorch 是一個開源的機器學(xué)習(xí)庫,它具有強大的 GPU 加速功能,方便用戶進(jìn)行深度學(xué)習(xí)模型的構(gòu)建和優(yōu)化,被廣泛應(yīng)用于計算機視覺、自然語言處理等領(lǐng)域。

PyTorch 的安裝

在開始學(xué)習(xí) PyTorch 之前,需要先安裝它。用戶可以通過編程獅(W3Cschool)網(wǎng)站查看詳細(xì)的安裝步驟,并根據(jù)自己的操作系統(tǒng)和 Python 版本來選擇合適的安裝方式。

PyTorch 的核心概念

  • 張量(Tensor) :類似于 NumPy 的多維數(shù)組,不同的是 PyTorch 的張量可以利用 GPU 加速計算。例如:

import torch


x = torch.tensor([1, 2, 3])  # 創(chuàng)建一個一維張量
y = torch.tensor([[4, 5], [6, 7]])  # 創(chuàng)建一個二維張量

  • 自動求導(dǎo)(Autograd) :PyTorch 的自動求導(dǎo)功能可以自動計算張量的梯度,這對于神經(jīng)網(wǎng)絡(luò)的訓(xùn)練至關(guān)重要。在張量上調(diào)用 .backward() 方法可以計算梯度;使用 torch.no_grad() 可以停止梯度追蹤。

x = torch.ones(2, 2, requires_grad=True)
y = x * 3
y.backward(torch.ones_like(x))
print(x.grad)

構(gòu)建神經(jīng)網(wǎng)絡(luò)

  • 定義網(wǎng)絡(luò)結(jié)構(gòu) :使用 torch.nn.Module 來定義神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)。例如,構(gòu)建一個簡單的全連接神經(jīng)網(wǎng)絡(luò):

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 10)  # 輸入層到隱藏層的連接,輸入維度為4,輸出維度為10
        self.fc2 = nn.Linear(10, 4)  # 隱藏層到輸出層的連接,輸入維度為10,輸出維度為4
        self.relu = nn.ReLU()  # 激活函數(shù)


    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


net = Net()

  • 定義損失函數(shù)和優(yōu)化器 :選擇合適的損失函數(shù)和優(yōu)化器對于模型的訓(xùn)練效果至關(guān)重要。常見的損失函數(shù)有均方誤差損失(MSELoss)、交叉熵?fù)p失(CrossEntropyLoss)等;常用的優(yōu)化器有隨機梯度下降(SGD)、Adam 等。

criterion = nn.MSELoss()  # 定義均方誤差損失函數(shù)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  # 定義隨機梯度下降優(yōu)化器,學(xué)習(xí)率為0.01

  • 訓(xùn)練網(wǎng)絡(luò) :使用訓(xùn)練數(shù)據(jù)對網(wǎng)絡(luò)進(jìn)行訓(xùn)練,通過前向傳播計算輸出,計算損失函數(shù)的值,然后進(jìn)行反向傳播更新網(wǎng)絡(luò)的參數(shù)。

for epoch in range(1000):  # 訓(xùn)練1000個周期
    optimizer.zero_grad()  # 清空梯度
    output = net(x_data)  # 前向傳播
    loss = criterion(output, y_data)  # 計算損失
    loss.backward()  # 反向傳播
    optimizer.step()  # 更新參數(shù)


    if epoch % 100 == 99:  # 每100個周期打印一次損失
        print('Epoch: {}, Loss: {}'.format(epoch+1, loss.item()))

PyTorch 的優(yōu)勢

PyTorch 具有動態(tài)計算圖的特點,這使得它在網(wǎng)絡(luò)結(jié)構(gòu)的構(gòu)建和調(diào)試方面更加靈活方便。此外,PyTorch 的社區(qū)非?;钴S,有大量的開源項目和教程可供學(xué)習(xí)和參考,例如在編程獅(W3Cschool)網(wǎng)站上就有很多優(yōu)秀的 PyTorch 教程。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號