PyTorch torch.utils.checkpoint

2020-09-15 11:40 更新

原文: PyTorch torch.utils.checkpoint

注意

通過(guò)在反向過(guò)程中為每個(gè)檢查點(diǎn)段重新運(yùn)行一個(gè)正向通過(guò)段來(lái)實(shí)現(xiàn)檢查點(diǎn)。 這可能會(huì)導(dǎo)致像 RNG 狀態(tài)這樣的持久狀態(tài)比沒(méi)有檢查點(diǎn)的狀態(tài)更先進(jìn)。 默認(rèn)情況下,檢查點(diǎn)包括處理 RNG 狀態(tài)的邏輯,以便與非檢查點(diǎn)通過(guò)相比,使用 RNG(例如,通過(guò)丟棄)的檢查點(diǎn)通過(guò)具有確定的輸出。 根據(jù)檢查點(diǎn)操作的運(yùn)行時(shí)間,存儲(chǔ)和恢復(fù) RNG 狀態(tài)的邏輯可能會(huì)導(dǎo)致性能下降。 如果不需要與非檢查點(diǎn)通過(guò)相比確定的輸出,則在每個(gè)檢查點(diǎn)期間向checkpointcheckpoint_sequential提供preserve_rng_state=False,以忽略存儲(chǔ)和恢復(fù) RNG 狀態(tài)。

隱藏邏輯將當(dāng)前設(shè)備以及所有 cuda Tensor 參數(shù)的設(shè)備的 RNG 狀態(tài)保存并恢復(fù)到run_fn。 但是,該邏輯無(wú)法預(yù)料用戶是否將張量移動(dòng)到run_fn本身內(nèi)的新設(shè)備。 因此,如果在run_fn中將張量移動(dòng)到新設(shè)備(“新”表示不屬于[當(dāng)前設(shè)備+張量參數(shù)的設(shè)備的集合]),則與非檢查點(diǎn)傳遞相比,確定性輸出將永遠(yuǎn)無(wú)法保證。

torch.utils.checkpoint.checkpoint(function, *args, **kwargs)?

檢查點(diǎn)模型或模型的一部分

檢查點(diǎn)通過(guò)將計(jì)算交換為內(nèi)存來(lái)工作。 檢查點(diǎn)部分沒(méi)有存儲(chǔ)整個(gè)計(jì)算圖的所有中間激活以進(jìn)行向后計(jì)算,而是由而不是保存中間激活,而是在向后傳遞時(shí)重新計(jì)算它們。 它可以應(yīng)用于模型的任何部分。

具體而言,在前向傳遞中,function將以torch.no_grad()方式運(yùn)行,即不存儲(chǔ)中間激活。 相反,前向傳遞保存輸入元組和function參數(shù)。 在向后遍歷中,檢索保存的輸入和function,并再次在function上計(jì)算正向遍歷,現(xiàn)在跟蹤中間激活,然后使用這些激活值計(jì)算梯度。

警告

檢查點(diǎn)不適用于 torch.autograd.grad() ,而僅適用于 torch.autograd.backward()

Warning

如果后退期間的function調(diào)用與前退期間的調(diào)用有任何不同,例如,由于某些全局變量,則檢查點(diǎn)版本將不相等,很遺憾,無(wú)法檢測(cè)到該版本。

參數(shù)

  • 函數(shù) –描述在模型的正向傳遞中或模型的一部分中運(yùn)行的內(nèi)容。 它還應(yīng)該知道如何處理作為元組傳遞的輸入。 例如,在 LSTM 中,如果用戶通過(guò)(activation, hidden),則function應(yīng)正確使用第一個(gè)輸入作為activation,第二個(gè)輸入作為hidden
  • reserve_rng_state (bool , 可選 , 默認(rèn)= True 在每個(gè)檢查點(diǎn)期間恢復(fù) RNG 狀態(tài)。
  • args –包含function輸入的元組

退貨

*args上運(yùn)行function的輸出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs, **kwargs)?

用于檢查點(diǎn)順序模型的輔助功能。

順序模型按順序(依次)執(zhí)行模塊/功能列表。 因此,我們可以將這樣的模型劃分為不同的段,并在每個(gè)段上檢查點(diǎn)。 除最后一個(gè)段外,所有段都將以torch.no_grad()方式運(yùn)行,即不存儲(chǔ)中間激活。 將保存每個(gè)檢查點(diǎn)線段的輸入,以便在后向傳遞中重新運(yùn)行該線段。

有關(guān)檢查點(diǎn)的工作方式,請(qǐng)參見(jiàn) checkpoint() 。

Warning

Checkpointing doesn't work with torch.autograd.grad(), but only with torch.autograd.backward().

Parameters

  • 功能 –一個(gè) torch.nn.Sequential 或要順序運(yùn)行的模塊或功能列表(包含模型)。
  • –在模型中創(chuàng)建的塊數(shù)
  • 輸入 –張量元組,它們是functions的輸入
  • preserve_rng_state (bool__, optional__, default=True) – Omit stashing and restoring the RNG state during each checkpoint.

Returns

*inputs上順序運(yùn)行functions的輸出

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)