在pytorch的交叉熵?fù)p失函數(shù)的學(xué)習(xí)中,weight慘啊作為交叉熵函數(shù)對(duì)應(yīng)參數(shù)的輸入值,它的使用并不是想象中的那么簡(jiǎn)單。接下來(lái)的這篇文章小編就來(lái)詳細(xì)的介紹一下交叉熵?fù)p失函數(shù)的weight參數(shù)怎么使用吧。
首先
必須將權(quán)重也轉(zhuǎn)為Tensor的cuda格式;
然后
將該class_weight作為交叉熵函數(shù)對(duì)應(yīng)參數(shù)的輸入值。
補(bǔ)充:關(guān)于pytorch的CrossEntropyLoss的weight參數(shù)
首先這個(gè)weight參數(shù)比想象中的要考慮的多
你可以試試下面代碼
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)
這里的手動(dòng)計(jì)算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加權(quán)呢?
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)
手算發(fā)現(xiàn),并不是單純的那權(quán)重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
發(fā)現(xiàn)了么,加權(quán)后,除以的是權(quán)重的和,不是數(shù)目的和。
我們?cè)衮?yàn)證一遍:
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人對(duì)loss的CE計(jì)算過(guò)程有疑問(wèn),我這里細(xì)致寫寫交叉熵的計(jì)算過(guò)程,就拿最后一個(gè)例子的loss4的計(jì)算說(shuō)明
小結(jié)
以上就是pytorch怎么使用交叉熵?fù)p失函數(shù)的全部?jī)?nèi)容,希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。