W3Cschool
恭喜您成為首批注冊(cè)用戶
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
注意
通過(guò)在反向過(guò)程中為每個(gè)檢查點(diǎn)段重新運(yùn)行一個(gè)正向通過(guò)段來(lái)實(shí)現(xiàn)檢查點(diǎn)。 這可能會(huì)導(dǎo)致像 RNG 狀態(tài)這樣的持久狀態(tài)比沒(méi)有檢查點(diǎn)的狀態(tài)更先進(jìn)。 默認(rèn)情況下,檢查點(diǎn)包括處理 RNG 狀態(tài)的邏輯,以便與非檢查點(diǎn)通過(guò)相比,使用 RNG(例如,通過(guò)丟棄)的檢查點(diǎn)通過(guò)具有確定的輸出。 根據(jù)檢查點(diǎn)操作的運(yùn)行時(shí)間,存儲(chǔ)和恢復(fù) RNG 狀態(tài)的邏輯可能會(huì)導(dǎo)致性能下降。 如果不需要與非檢查點(diǎn)通過(guò)相比確定的輸出,則在每個(gè)檢查點(diǎn)期間向checkpoint
或checkpoint_sequential
提供preserve_rng_state=False
,以忽略存儲(chǔ)和恢復(fù) RNG 狀態(tài)。
隱藏邏輯將當(dāng)前設(shè)備以及所有 cuda Tensor 參數(shù)的設(shè)備的 RNG 狀態(tài)保存并恢復(fù)到run_fn
。 但是,該邏輯無(wú)法預(yù)料用戶是否將張量移動(dòng)到run_fn
本身內(nèi)的新設(shè)備。 因此,如果在run_fn
中將張量移動(dòng)到新設(shè)備(“新”表示不屬于[當(dāng)前設(shè)備+張量參數(shù)的設(shè)備的集合]),則與非檢查點(diǎn)傳遞相比,確定性輸出將永遠(yuǎn)無(wú)法保證。
torch.utils.checkpoint.checkpoint(function, *args, **kwargs)?
檢查點(diǎn)模型或模型的一部分
檢查點(diǎn)通過(guò)將計(jì)算交換為內(nèi)存來(lái)工作。 檢查點(diǎn)部分沒(méi)有存儲(chǔ)整個(gè)計(jì)算圖的所有中間激活以進(jìn)行向后計(jì)算,而是由而不是保存中間激活,而是在向后傳遞時(shí)重新計(jì)算它們。 它可以應(yīng)用于模型的任何部分。
具體而言,在前向傳遞中,function
將以torch.no_grad()
方式運(yùn)行,即不存儲(chǔ)中間激活。 相反,前向傳遞保存輸入元組和function
參數(shù)。 在向后遍歷中,檢索保存的輸入和function
,并再次在function
上計(jì)算正向遍歷,現(xiàn)在跟蹤中間激活,然后使用這些激活值計(jì)算梯度。
警告
檢查點(diǎn)不適用于 torch.autograd.grad()
,而僅適用于 torch.autograd.backward()
。
Warning
如果后退期間的function
調(diào)用與前退期間的調(diào)用有任何不同,例如,由于某些全局變量,則檢查點(diǎn)版本將不相等,很遺憾,無(wú)法檢測(cè)到該版本。
參數(shù)
(activation, hidden)
,則function
應(yīng)正確使用第一個(gè)輸入作為activation
,第二個(gè)輸入作為hidden
function
輸入的元組退貨
在*args
上運(yùn)行function
的輸出
torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs, **kwargs)?
用于檢查點(diǎn)順序模型的輔助功能。
順序模型按順序(依次)執(zhí)行模塊/功能列表。 因此,我們可以將這樣的模型劃分為不同的段,并在每個(gè)段上檢查點(diǎn)。 除最后一個(gè)段外,所有段都將以torch.no_grad()
方式運(yùn)行,即不存儲(chǔ)中間激活。 將保存每個(gè)檢查點(diǎn)線段的輸入,以便在后向傳遞中重新運(yùn)行該線段。
有關(guān)檢查點(diǎn)的工作方式,請(qǐng)參見(jiàn) checkpoint()
。
Warning
Checkpointing doesn't work with torch.autograd.grad()
, but only with torch.autograd.backward()
.
Parameters
torch.nn.Sequential
或要順序運(yùn)行的模塊或功能列表(包含模型)。functions
的輸入Returns
在*inputs
上順序運(yùn)行functions
的輸出
例
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
Copyright©2021 w3cschool編程獅|閩ICP備15016281號(hào)-3|閩公網(wǎng)安備35020302033924號(hào)
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號(hào)
聯(lián)系方式:
更多建議: