PyTorch torch.utils.cpp_extension

2025-07-02 14:27 更新

PyTorch C++ 擴展開發(fā)詳解

一、為什么需要 C++ 擴展?

在深度學習開發(fā)中,我們有時會遇到 Python 實現(xiàn)的性能瓶頸,或者需要調(diào)用現(xiàn)有的 C++_nlank 或 CUDA 代碼。PyTorch 提供了強大的 C++ 擴展功能,允許開發(fā)者將自定義的 C++ 或 CUDA 代碼集成到 PyTorch 模型中,從而提升性能或復(fù)用代碼。

二、PyTorch C++ 擴展的核心工具

(一)torch.utils.cpp_extension.CppExtension

用于創(chuàng)建一個 setuptools.Extension,用于構(gòu)建 C++ 擴展。

  • 參數(shù)
    • name:擴展的名稱。
    • sources:C++ 源文件的路徑列表。
    • 其他參數(shù)將傳遞給 setuptools.Extension 構(gòu)造函數(shù)。

  • 示例

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension


setup(
    name='extension',
    ext_modules=[
        CppExtension(
            name='extension',
            sources=['extension.cpp'],
            extra_compile_args=['-g']
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

(二)torch.utils.cpp_extension.CUDAExtension

用于創(chuàng)建一個 setuptools.Extension,用于構(gòu)建 CUDA/C++ 擴展。它會自動包含 CUDA 的相關(guān)路徑和庫。

  • 參數(shù)
    • name:擴展的名稱。
    • sources:C++ 和 CUDA 源文件的路徑列表。
    • 其他參數(shù)將傳遞給 setuptools.Extension 構(gòu)造函數(shù)。

  • 示例

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='cuda_extension',
    ext_modules=[
        CUDAExtension(
            name='cuda_extension',
            sources=['extension.cpp', 'extension_kernel.cu'],
            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

(三)torch.utils.cpp_extension.BuildExtension

自定義 setuptools 構(gòu)建擴展的類。它負責傳遞編譯器標志,并支持混合 C++/CUDA 編譯。

(四)torch.utils.cpp_extension.load

即時加載 PyTorch C++ 擴展(JIT)。它會編譯源代碼并將其作為模塊加載到當前 Python 進程中。

  • 參數(shù)
    • name:擴展的名稱。
    • sources:C++ 源文件的路徑列表。
    • extra_cflags:傳遞給編譯器的額外標志。
    • extra_cuda_cflags:傳遞給 nvcc 的額外標志。
    • extra_ldflags:傳遞給鏈接器的額外標志。
    • build_directory:構(gòu)建工作區(qū)的路徑。
    • verbose:是否開啟詳細日志。
    • with_cuda:是否包含 CUDA 支持。
    • is_python_module:是否作為 Python 模塊加載。

  • 示例

from torch.utils.cpp_extension import load


module = load(
    name='extension',
    sources=['extension.cpp', 'extension_kernel.cu'],
    extra_cflags=['-O2'],
    verbose=True
)

(五)torch.utils.cpp_extension.load_inline

從字符串源即時加載 PyTorch C++ 擴展(JIT)。它與 load 類似,但源代碼以字符串形式提供。

  • 參數(shù)
    • name:擴展的名稱。
    • cpp_sources:C++ 源代碼字符串或字符串列表。
    • cuda_sources:CUDA 源代碼字符串或字符串列表。
    • functions:需要生成綁定的函數(shù)名稱列表或字典。
    • 其他參數(shù)與 load 類似。

  • 示例

from torch.utils.cpp_extension import load_inline


source = '''
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
    return x.sin() + y.sin();
}
'''


module = load_inline(
    name='inline_extension',
    cpp_sources=[source],
    functions=['sin_add']
)

(六)輔助工具

  • torch.utils.cpp_extension.include_paths(cuda=False):獲取構(gòu)建 C++ 或 CUDA 擴展所需的包含路徑。
  • torch.utils.cpp_extension.check_compiler_abi_compatibility(compiler):驗證編譯器是否與 PyTorch 兼容。
  • torch.utils.cpp_extension.verify_ninja_availability():檢查系統(tǒng)上是否有 ninja 構(gòu)建系統(tǒng)。

三、開發(fā)和構(gòu)建 C++ 擴展的步驟

(一)準備源代碼

編寫 C++ 或 CUDA 源代碼,實現(xiàn)自定義的功能或操作。

(二)編寫 setup 腳本

創(chuàng)建一個 setup.py 腳本,使用 CppExtensionCUDAExtension 定義擴展,并使用 BuildExtension 構(gòu)建擴展。

(三)構(gòu)建和安裝擴展

運行以下命令構(gòu)建和安裝擴展:

python setup.py install

(四)即時加載擴展

使用 loadload_inline 函數(shù)即時加載擴展,無需編寫 setup.py 腳本。

四、總結(jié)

通過本教程,我們詳細介紹了 PyTorch C++ 擴展的開發(fā)工具和流程。利用這些工具,開發(fā)者可以輕松地將自定義的 C++ 或 CUDA 代碼集成到 PyTorch 項目中,從而提升模型性能或復(fù)用現(xiàn)有代碼。掌握 C++ 擴展的開發(fā)方法,能夠幫助您在深度學習項目中更好地優(yōu)化性能和利用硬件資源。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號