PyTorch torch.hub

2020-09-15 10:12 更新

原文:PyTorch torch.hub

Pytorch Hub 是經(jīng)過預先訓練的模型資料庫,旨在促進研究的可重復性。

發(fā)布模型

Pytorch Hub 支持通過添加簡單的hubconf.py文件將預訓練的模型(模型定義和預訓練的權重)發(fā)布到 github 存儲庫;

hubconf.py可以有多個入口點。 每個入口點都定義為 python 函數(shù)(例如:您要發(fā)布的經(jīng)過預先訓練的模型)。

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

如何實現(xiàn)入口點?

如果我們擴展pytorch/vision/hubconf.py中的實現(xiàn),則以下代碼段指定了resnet18模型的入口點。 在大多數(shù)情況下,在hubconf.py中導入正確的功能就足夠了。 在這里,我們僅以擴展版本為例來說明其工作原理。

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18


## resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model

  • dependencies變量是加載模型所需的軟件包名稱的列表。 請注意,這可能與訓練模型所需的依賴項稍有不同。
  • argskwargs傳遞給實際的可調用函數(shù)。
  • 該函數(shù)的文檔字符串用作幫助消息。 它解釋了模型做什么以及允許的位置/關鍵字參數(shù)是什么。 強烈建議在此處添加一些示例。
  • Entrypoint 函數(shù)可以返回模型(nn.module),也可以返回輔助工具以使用戶工作流程更流暢,例如 標記器。
  • 帶下劃線前綴的可調用項被視為輔助功能,不會在torch.hub.list()中顯示。
  • 預訓練的權重既可以存儲在 github 存儲庫中,也可以由torch.hub.load_state_dict_from_url()加載。 如果少于 2GB,建議將其附加到項目版本,并使用該版本中的網(wǎng)址。 在上面的示例中,torchvision.models.resnet.resnet18處理pretrained,或者,您可以在入口點定義中添加以下邏輯。

if pretrained:
    # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)


    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 發(fā)布的模型應至少在分支/標簽中。 不能是隨機提交。

從集線器加載模型

Pytorch Hub 提供了便捷的 API,可通過torch.hub.list()瀏覽集線器中的所有可用模型,通過torch.hub.help()顯示文檔字符串和示例,并使用torch.hub.load()加載經(jīng)過預先訓練的模型

torch.hub.list(github, force_reload=False)?

列出 <cite>github</cite> hubconf 中可用的所有入口點。

參數(shù)

  • github (字符串)–格式為“ repo_owner / repo_name [:tag_name]”的字符串,帶有可選的標記/分支。 如果未指定,則默認分支為<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
  • force_reload (bool , 可選)–是否放棄現(xiàn)有緩存并強制重新下載。 默認值為<cite>否</cite>。

退貨

可用入口點名稱的列表

返回類型

入口點

>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)

torch.hub.help(github, model, force_reload=False)?

顯示入口點<cite>模型</cite>的文檔字符串。

Parameters

  • github (字符串)–格式為< repo_owner / repo_name [:tag_name] [:HT_7]的字符串,帶有可選的標記/分支。 如果未指定,則默認分支為<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
  • 模型(字符串)–在存儲庫的 hubconf.py 中定義的入口點名稱字符串
  • force_reload (bool__, optional) – whether to discard the existing cache and force a fresh download. Default is <cite>False</cite>.

Example

>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))

torch.hub.load(github, model, *args, **kwargs)?

使用預訓練的權重從 github 存儲庫加載模型。

Parameters

  • github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is <cite>master</cite> if not specified. Example: 'pytorch/vision[:hub]'
  • model (string) – a string of entrypoint name defined in repo's hubconf.py
  • args (可選*)–可調用<cite>模型</cite>的相應 args。
  • force_reload (bool , 可選)–是否無條件強制重新下載 github 存儲庫。 默認值為<cite>否</cite>。
  • 詳細 (bool 可選)–如果為 False,則忽略有關命中本地緩存的消息。 請注意,有關首次下載的消息不能被靜音。 默認值為<cite>為真</cite>。
  • \ kwargs (可選)–可調用<cite>模型</cite>的相應 kwargs。

Returns

具有相應預訓練權重的單個模型。

Example

>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)

torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)?

將給定 URL 上的對象下載到本地路徑。

Parameters

  • url (字符串)–要下載的對象的 URL
  • dst (字符串)–保存對象的完整路徑,例如 <cite>/ tmp / temporary_file</cite>
  • hash_prefix (字符串 , 可選))–如果不是 None,則下載的 SHA256 文件應以 <cite>hash_prefix</cite> 開頭。 默認值:無
  • 進度 (bool 可選)–是否顯示 stderr 的進度條默認值:True

Example

>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)?

將 Torch 序列化對象加載到給定的 URL。

如果下載的文件是 zip 文件,它將被自動解壓縮。

如果 <cite>model_dir</cite> 中已經(jīng)存在該對象,則將其反序列化并返回。 <cite>model_dir</cite> 的默認值為$TORCH_HOME/checkpoints,其中環(huán)境變量$TORCH_HOME的默認值為$XDG_CACHE_HOME/torch$XDG_CACHE_HOME遵循 Linux 文件系統(tǒng)布局的 X 設計組規(guī)范,如果未設置,則默認值為~/.cache。

Parameters

  • url (string) – URL of the object to download
  • model_dir (字符串 , 可選)–保存對象的目錄
  • map_location (可選)–指定如何重新映射存儲位置的函數(shù)或命令(請參見 torch.load)
  • 進度 (bool 可選)–是否顯示 stderr 進度條。 默認值:True
  • check_hash (bool , 可選)–如果為 True,則 URL 的文件名部分應遵循命名約定filename-<sha256>.ext,其中[ <sha256>是文件內(nèi)容的 SHA256 哈希值的前 8 位或更多位。 哈希用于確保唯一的名稱并驗證文件的內(nèi)容。 默認值:False

Example

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

運行加載的模型:

注意,torch.load()中的*args, **kwargs用于實例化模型。 加載模型后,如何找到可以使用該模型的功能? 建議的工作流程是

  • dir(model)查看模型的所有可用方法。
  • help(model.foo)檢查model.foo需要執(zhí)行哪些參數(shù)

為了幫助用戶探索而又不來回參考文檔,我們強烈建議回購所有者使功能幫助消息清晰明了。 包含一個最小的工作示例也很有幫助。

我下載的模型保存在哪里?

這些位置按以下順序使用

  • 呼叫hub.set_dir(<PATH_TO_HUB_DIR>)
  • $TORCH_HOME/hub,如果設置了環(huán)境變量TORCH_HOME。
  • $XDG_CACHE_HOME/torch/hub,如果設置了環(huán)境變量XDG_CACHE_HOME。
  • ~/.cache/torch/hub

torch.hub.set_dir(d)?

(可選)將 hub_dir 設置為本地目錄,以保存下載的模型&權重。

如果未調用set_dir,則默認路徑為$TORCH_HOME/hub,其中環(huán)境變量$TORCH_HOME默認為$XDG_CACHE_HOME/torch。 $XDG_CACHE_HOME遵循 Linux 文件系統(tǒng)布局的 X 設計組規(guī)范,如果未設置環(huán)境變量,則默認值為~/.cache。

Parameters

d (字符串)–本地文件夾的路徑,用于保存下載的模型&權重。

緩存邏輯

默認情況下,加載文件后我們不會清理文件。 如果hub_dir中已經(jīng)存在,則集線器默認使用緩存。

用戶可以通過調用hub.load(..., force_reload=True)來強制重新加載。 這將刪除現(xiàn)有的 github 文件夾和下載的權重,重新初始化新的下載。 當更新發(fā)布到同一分支時,此功能很有用,用戶可以跟上最新版本。

已知限制:

Torch 集線器通過導入軟件包來進行工作,就像安裝軟件包一樣。 在 Python 中導入會帶來一些副作用。 例如,您可以在 Python 緩存sys.modulessys.path_importer_cache中看到新項目,這是正常的 Python 行為。

在這里值得一提的已知限制是用戶無法相同的 python 進程中加載同一存儲庫的兩個不同分支。 就像在 Python 中安裝兩個具有相同名稱的軟件包一樣,這是不好的。 快取可能會加入聚會,如果您實際嘗試的話會給您帶來驚喜。 當然,將它們分別加載是完全可以的。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號