PyTorch torch.utils.model_zoo

2025-07-02 17:36 更新

PyTorch 模型加載與遷移學(xué)習(xí):torch.hub 實踐指南

一、torch.hub 的簡介與優(yōu)勢

torch.hub 是 PyTorch 提供的一個方便的工具,用于加載和使用預(yù)訓(xùn)練模型。它簡化了模型的獲取過程,使得用戶可以輕松地從模型倉庫中下載并加載預(yù)訓(xùn)練模型。這對于快速上手深度學(xué)習(xí)項目、進行模型遷移學(xué)習(xí)以及復(fù)現(xiàn)研究結(jié)果非常有幫助。

二、模型加載函數(shù)詳解

(一)torch.hub.load

用于從 GitHub 倉庫加載模型或模型組件。您可以直接通過模型倉庫的路徑和模型名稱來獲取模型。

(二)torch.hub.load_state_dict_from_url

從指定的 URL 下載并加載模型的狀態(tài)字典。此函數(shù)非常有用,當(dāng)您需要直接從網(wǎng)絡(luò)加載模型權(quán)重時。

三、實際操作示例

(一)加載預(yù)訓(xùn)練模型

我們以加載預(yù)訓(xùn)練的 ResNet-18 模型為例,展示如何使用 torch.hub。

import torch


## 加載 ResNet-18 模型
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)


## 將模型設(shè)置為評估模式
model.eval()


## 打印模型結(jié)構(gòu)
print(model)

(二)加載模型狀態(tài)字典

如果您需要從 URL 加載模型的狀態(tài)字典,可以使用以下方法:

state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', map_location=torch.device('cpu'))


## 加載狀態(tài)字典到模型
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18')
model.load_state_dict(state_dict)

四、應(yīng)用場景與技巧

(一)遷移學(xué)習(xí)

torch.hub 在遷移學(xué)習(xí)中非常有用。您可以加載一個預(yù)訓(xùn)練模型,然后根據(jù)自己的任務(wù)需求進行微調(diào)。

## 加載預(yù)訓(xùn)練模型并進行微調(diào)
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)


## 替換模型的最后一層以適應(yīng)新的分類任務(wù)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)  # 假設(shè)新的分類任務(wù)有 2 個類別


## 將模型設(shè)置為訓(xùn)練模式
model.train()

(二)復(fù)現(xiàn)研究結(jié)果

在復(fù)現(xiàn)其他研究者的模型時,您可以直接使用 torch.hub 加載他們提供的預(yù)訓(xùn)練模型或模型組件。

## 假設(shè)有一個研究者提供的模型倉庫和模型名稱
## model = torch.hub.load('researcher/model_repo', 'model_name', source='github')

(三)使用不同版本的模型

torch.hub 支持加載特定版本的模型,這在模型更新后需要保持一致性時非常有用。

## 加載特定版本的模型
model = torch.hub.load('pytorch/vision:v0.5.0', 'resnet18', pretrained=True)

五、總結(jié)

通過本教程,我們詳細介紹了如何使用 torch.hub 加載和使用預(yù)訓(xùn)練模型,以及在遷移學(xué)習(xí)和復(fù)現(xiàn)研究結(jié)果中的應(yīng)用。torch.hub 提供了便捷的接口,使得模型的獲取和使用變得更加簡單。掌握這些技巧,可以幫助您更高效地進行深度學(xué)習(xí)開發(fā)和研究。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號