App下載

Pytorch半精度網(wǎng)絡(luò)訓(xùn)練需要注意哪些問(wèn)題?

猿友 2021-07-15 14:05:44 瀏覽數(shù) (4275)
反饋

為了提高pytorch的模型訓(xùn)練的效率,我們有時(shí)候會(huì)選擇放棄部分精度來(lái)?yè)Q取運(yùn)算耗時(shí)的縮減。也就是說(shuō),在對(duì)精度要求不是那么高的情況下我們可以使用pytorch半精度網(wǎng)絡(luò)訓(xùn)練。但是在使用pytorch進(jìn)行半精度網(wǎng)絡(luò)訓(xùn)練的時(shí)候可能會(huì)出現(xiàn)一些問(wèn)題,小編將這些問(wèn)題進(jìn)行了一個(gè)總結(jié),各位小伙伴可以進(jìn)行參考。

用Pytorch1.0進(jìn)行半精度浮點(diǎn)型網(wǎng)絡(luò)訓(xùn)練需要注意下問(wèn)題:

1、網(wǎng)絡(luò)要在GPU上跑,模型和輸入樣本數(shù)據(jù)都要cuda().half()

2、模型參數(shù)轉(zhuǎn)換為half型,不必索引到每層,直接model.cuda().half()即可

3、對(duì)于半精度模型,優(yōu)化算法,Adam我在使用過(guò)程中,在某些參數(shù)的梯度為0的時(shí)候,更新權(quán)重后,梯度為零的權(quán)重變成了NAN,這非常奇怪,但是Adam算法對(duì)于全精度數(shù)據(jù)類型卻沒有這個(gè)問(wèn)題。

另外,SGD算法對(duì)于半精度和全精度計(jì)算均沒有問(wèn)題。

還有一個(gè)問(wèn)題是不知道是不是網(wǎng)絡(luò)結(jié)構(gòu)比較小的原因,使用半精度的訓(xùn)練速度還沒有全精度快。這個(gè)值得后續(xù)進(jìn)一步探索。

對(duì)于上面的這個(gè)問(wèn)題,的確是網(wǎng)絡(luò)很小的情況下,在1080Ti上半精度浮點(diǎn)型沒有很明顯的優(yōu)勢(shì),但是當(dāng)網(wǎng)絡(luò)變大之后,半精度浮點(diǎn)型要比全精度浮點(diǎn)型要快。

但具體快多少和模型的大小以及輸入樣本大小有關(guān)系,我測(cè)試的是要快1/6,同時(shí),半精度浮點(diǎn)型在占用內(nèi)存上比較有優(yōu)勢(shì),對(duì)于精度的影響尚未探究。

將網(wǎng)絡(luò)再變大些,epoch的次數(shù)也增大,半精度和全精度的時(shí)間差就表現(xiàn)出來(lái)了,在訓(xùn)練的時(shí)候。

補(bǔ)充:pytorch半精度,混合精度,單精度訓(xùn)練的區(qū)別amp.initialize

看代碼吧~

mixed_precision = True
try:  # Mixed precision training https://github.com/NVIDIA/apex
    from apex import amp
except:
    mixed_precision = False  # not installed

 model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=1)

為了幫助提高Pytorch的訓(xùn)練效率,英偉達(dá)提供了混合精度訓(xùn)練工具Apex。號(hào)稱能夠在不降低性能的情況下,將模型訓(xùn)練的速度提升2-4倍,訓(xùn)練顯存消耗減少為之前的一半。

文檔地址是:https://nvidia.github.io/apex/index.html

該 工具 提供了三個(gè)功能,amp、parallel和normalization。由于目前該工具還是0.1版本,功能還是很基礎(chǔ)的,在最后一個(gè)normalization功能中只提供了LayerNorm層的復(fù)現(xiàn),實(shí)際上在后續(xù)的使用過(guò)程中會(huì)發(fā)現(xiàn),出現(xiàn)問(wèn)題最多的是pytorch的BN層。

第二個(gè)工具是pytorch的分布式訓(xùn)練的復(fù)現(xiàn),在文檔中描述的是和pytorch中的實(shí)現(xiàn)等價(jià),在代碼中可以選擇任意一個(gè)使用,實(shí)際使用過(guò)程中發(fā)現(xiàn),在使用混合精度訓(xùn)練時(shí),使用Apex復(fù)現(xiàn)的parallel工具,能避免一些bug。

默認(rèn)訓(xùn)練方式是 單精度f(wàn)loat32

import torch
model = torch.nn.Linear(D_in, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

半精度 model(img.half())


接下來(lái)是混合精度的實(shí)現(xiàn),這里主要用到Apex的amp工具。

代碼修改為:

加上這一句封裝,

model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
import torch
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 # loss.backward()
 with amp.scale_loss(loss, optimizer) as scaled_loss:
     scaled_loss.backward()

 optimizer.step()
 optimizer.zero_grad()

實(shí)際流程為:調(diào)用amp.initialize按照預(yù)定的opt_level對(duì)model和optimizer進(jìn)行設(shè)置。在計(jì)算loss時(shí)使用amp.scale_loss進(jìn)行回傳。

需要注意以下幾點(diǎn):

在調(diào)用amp.initialize之前,模型需要放在GPU上,也就是需要調(diào)用cuda()或者to()。

在調(diào)用amp.initialize之前,模型不能調(diào)用任何分布式設(shè)置函數(shù)。

此時(shí)輸入數(shù)據(jù)不需要在轉(zhuǎn)換為半精度。

在使用混合精度進(jìn)行計(jì)算時(shí),最關(guān)鍵的參數(shù)是opt_level。他一共含有四種設(shè)置值:‘00',‘01',‘02',‘03'。實(shí)際上整個(gè)amp.initialize的輸入?yún)?shù)很多:

但是在實(shí)際使用過(guò)程中發(fā)現(xiàn),設(shè)置opt_level即可,這也是文檔中例子的使用方法,甚至在不同的opt_level設(shè)置條件下,其他的參數(shù)會(huì)變成無(wú)效。(已知BUG:使用‘01'時(shí)設(shè)置keep_batchnorm_fp32的值會(huì)報(bào)錯(cuò))

概括起來(lái):

00相當(dāng)于原始的單精度訓(xùn)練。01在大部分計(jì)算時(shí)采用半精度,但是所有的模型參數(shù)依然保持單精度,對(duì)于少數(shù)單精度較好的計(jì)算(如softmax)依然保持單精度。02相比于01,將模型參數(shù)也變?yōu)榘刖取?/p>

03基本等于最開始實(shí)驗(yàn)的全半精度的運(yùn)算。值得一提的是,不論在優(yōu)化過(guò)程中,模型是否采用半精度,保存下來(lái)的模型均為單精度模型,能夠保證模型在其他應(yīng)用中的正常使用。這也是Apex的一大賣點(diǎn)。

在Pytorch中,BN層分為train和eval兩種操作。

實(shí)現(xiàn)時(shí)若為單精度網(wǎng)絡(luò),會(huì)調(diào)用CUDNN進(jìn)行計(jì)算加速。常規(guī)訓(xùn)練過(guò)程中BN層會(huì)被設(shè)為train。Apex優(yōu)化了這種情況,通過(guò)設(shè)置keep_batchnorm_fp32參數(shù),能夠保證此時(shí)BN層使用CUDNN進(jìn)行計(jì)算,達(dá)到最好的計(jì)算速度。

但是在一些fine tunning場(chǎng)景下,BN層會(huì)被設(shè)為eval(我的模型就是這種情況)。此時(shí)keep_batchnorm_fp32的設(shè)置并不起作用,訓(xùn)練會(huì)產(chǎn)生數(shù)據(jù)類型不正確的bug。此時(shí)需要人為的將所有BN層設(shè)置為半精度,這樣將不能使用CUDNN加速。

一個(gè)設(shè)置的參考代碼如下:

def fix_bn(m):
 classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
     m.eval().half()

model.apply(fix_bn)

實(shí)際測(cè)試下來(lái),最后的模型準(zhǔn)確度上感覺差別不大,可能有輕微下降;時(shí)間上變化不大,這可能會(huì)因不同的模型有差別;顯存開銷上確實(shí)有很大的降低。

小結(jié)

通過(guò)設(shè)置pytorch半精度網(wǎng)絡(luò)訓(xùn)練,可以提高pytorch的訓(xùn)練效率。以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。

0 人點(diǎn)贊