docs के अनुसार, CrossEntropyLoss मानदंड संयुक्त LogSoftmax फ़ंक्शन और NLLLoss मानदंड।

यह सब ठीक है और ठीक है, लेकिन इसका परीक्षण इस दावे को प्रमाणित नहीं करता है (यानी दावा विफल रहता है):

model_nll = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2),
                          nn.LogSoftmax(dim=1))


model_ce = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2),
                          nn.LogSoftmax(dim=1))

loss_fn_ce = nn.CrossEntropyLoss()
loss_fn_nll = nn.NLLLoss()

t = torch.rand(1,3072)
target = torch.tensor([1])

with torch.no_grad():
    loss_nll = loss_fn_nll(model_nll(t), target)
    loss_ce = loss_fn_ce(model_ce(t), target)
    assert torch.eq(loss_nll, loss_ce)

मैं स्पष्ट रूप से यहाँ कुछ बुनियादी याद कर रहा हूँ।

0
Marcel Coetzee 9 सितंबर 2021, 21:52

2 जवाब

सबसे बढ़िया उत्तर

जैसा कि आपने देखा, वज़न बेतरतीब ढंग से आरंभ किया जाता है।

समान भार साझा करने वाले दो मॉड्यूल प्राप्त करने का एक तरीका केवल state_dict एक की स्थिति और दूसरे पर सेट करें load_state_dict

यह एक-लाइनर है:

>>> model_ce.load_state_dict(model_nll.state_dict())
2
Ivan 9 सितंबर 2021, 22:16

निम्नलिखित अभिकथन गुजरता है:

model = nn.Sequential(
    nn.Linear(3072, 1024),
    nn.Tanh(),
    nn.Linear(1024, 512),
    nn.Tanh(),
    nn.Linear(512, 128),
    nn.Tanh(),
    nn.Linear(128, 2),
)


loss_fn_nll = nn.NLLLoss()
loss_fn_ce = nn.CrossEntropyLoss()

t = torch.rand(1, 3072)
target = torch.tensor([1])

with torch.no_grad():

    loss_nll = loss_fn_nll(nn.LogSoftmax(dim=1)(model(t)), target)
    loss_ce = loss_fn_ce(model(t), target)

    assert torch.eq(loss_nll, loss_ce)

मुझे लगता है कि मूल प्रश्न में दो नेटवर्क में वज़न बेतरतीब ढंग से उत्पन्न होता है। torch.manual_seed(0) के साथ भी यह अभी भी समान नहीं है।

0
Marcel Coetzee 9 सितंबर 2021, 22:08