PyTorch torch.onnx

2025-06-25 15:30 更新

一、什么是 ONNX?

ONNX(Open Neural Network Exchange)是一種開放的神經網絡交換格式,旨在促進不同深度學習框架之間的模型互操作性。通過將 PyTorch 模型導出為 ONNX 格式,我們可以在其他支持 ONNX 的框架和工具中使用這些模型,如 Caffe2、Microsoft ONNX Runtime 等。這對于模型的部署和優(yōu)化具有重要意義。

二、PyTorch 模型導出 ONNX 的基本流程

(一)示例:將預訓練的 AlexNet 導出到 ONNX

  1. 導入必要的庫

import torch
import torchvision

  1. 準備輸入和模型

## 創(chuàng)建一個虛擬輸入,形狀為 (10, 3, 224, 224),并將其移動到 GPU 上
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')


## 加載預訓練的 AlexNet 模型,并將其移動到 GPU 上
model = torchvision.models.alexnet(pretrained=True).cuda()


## 將模型設置為評估模式(非訓練模式)
model.eval()

  1. 定義輸入和輸出名稱

## 為模型的輸入和參數(shù)指定名稱,以提高模型圖的可讀性
input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(16)]
output_names = ["output1"]

  1. 導出模型到 ONNX

## 將模型導出到 ONNX 文件 "alexnet.onnx"
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

(二)驗證導出的 ONNX 模型

  1. 使用 ONNX 庫驗證模型

import onnx


## 加載導出的 ONNX 模型
model = onnx.load("alexnet.onnx")


## 檢查模型的 IR 是否有效
onnx.checker.check_model(model)


## 打印模型圖的可讀表示形式
print(onnx.helper.printable_graph(model.graph))

  1. 使用 ONNX Runtime 運行模型

import onnxruntime as ort
import numpy as np


## 創(chuàng)建 ONNX Runtime 推理會話
ort_session = ort.InferenceSession('alexnet.onnx')


## 準備輸入數(shù)據(jù)
input_data = np.random.randn(10, 3, 224, 224).astype(np.float32)


## 運行模型并獲取輸出
outputs = ort_session.run(None, {'actual_input_1': input_data})


## 打印輸出結果
print(outputs[0])

三、跟蹤與腳本編寫

(一)基于跟蹤的導出器

基于跟蹤的導出器通過執(zhí)行一次模型并導出在此運行期間實際執(zhí)行的運算符來操作。這意味著如果模型是動態(tài)的(例如,根據(jù)輸入數(shù)據(jù)更改行為),導出可能不準確。同樣,跟蹤可能僅對特定的輸入大小有效。我們建議檢查模型跟蹤并確保所跟蹤的運算符看起來合理。

例如:

import torch


## 定義一個簡單的模型類
class LoopModel(torch.nn.Module):
    def forward(self, x, y):
        for i in range(y):
            x = x + i
        return x


model = LoopModel()
dummy_input = torch.ones(2, 3, dtype=torch.long)
loop_count = torch.tensor(5, dtype=torch.long)


## 使用基于跟蹤的導出器導出模型
torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True)

(二)基于腳本的導出器

基于腳本的導出器表示要導出的模型是 ScriptModule。ScriptModule 是 TorchScript 中的核心數(shù)據(jù)結構,TorchScript 是 Python 語言的子集,可用于從 PyTorch 代碼創(chuàng)建可序列化和可優(yōu)化的模型。

例如:

@torch.jit.script
def loop(x, y):
    for i in range(int(y)):
        x = x + i
    return x


class LoopModel2(torch.nn.Module):
    def forward(self, x, y):
        return loop(x, y)


model = LoopModel2()
dummy_input = torch.ones(2, 3, dtype=torch.long)
loop_count = torch.tensor(5, dtype=torch.long)


## 使用基于腳本的導出器導出模型
torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True, input_names=['input_data', 'loop_range'])

四、局限性和常見問題

  1. 張量就地索引分配不支持 :目前導出中不支持張量就地索引分配,如 data[index] = new_data??梢酝ㄟ^使用 scatter_ 運算符來解決此類問題。
  2. ONNX 中沒有張量列表的概念 :這使得導出消耗或產生張量列表的運算符變得困難,尤其是在導出時不知道張量列表的長度的情況下。
  3. 輸入大小固定問題 :如果模型應接受動態(tài)形狀的輸入,可以在導出 API 中使用參數(shù) dynamic_axes 來指定動態(tài)軸。
  4. 隱式標量數(shù)據(jù)類型轉換問題 :ONNX 不支持隱式標量數(shù)據(jù)類型轉換,但導出器會嘗試處理該部分。對于無法自動處理的情況,需要手動提供數(shù)據(jù)類型信息。

五、總結

通過本教程,我們詳細介紹了如何將 PyTorch 模型導出為 ONNX 格式,包括基本流程、跟蹤與腳本編寫的區(qū)別,以及一些局限性和常見問題的解決方案。將模型導出為 ONNX 格式可以提高模型的互操作性和部署靈活性,使我們能夠在各種支持 ONNX 的框架和工具中使用這些模型。

在實際應用中,根據(jù)模型的特點和需求,選擇合適的導出方式(基于跟蹤或基于腳本),并注意處理可能遇到的局限性和問題,可以確保模型成功導出并能夠在目標環(huán)境中正常運行。

以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號