मैंने पाइटोरच प्रलेखन का पालन किया है और एमएनआईएसटी डेटासेट के लिए एक अत्यंत सरल क्लासिफायरियर बनाया है। नीचे मेरा कोड है:

import numpy as np

import torch
import torchvision
from torchvision import transforms, datasets

import torch.nn as nn
import torch.nn.functional as F

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
    ])

train = datasets.MNIST('', train=True, download=True, transform=transform)
test = datasets.MNIST('', train=False, download=True, transform=transform)
trainset = torch.utils.data.DataLoader(train, batch_size=1, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=1, shuffle=False)

class Classifier(nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Classifier, self).__init__()
        self.linear_1 = torch.nn.Linear(D_in, H)
        self.linear_2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = self.linear_1(x).clamp(min=0)
        x = self.linear_2(x)
        return F.log_softmax(x, dim=1)


net = Classifier(28*28, 128, 10)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(3):
    running_loss = 0.0
    for X, label in iter(trainset):
        X = X.view(28*28, -1)

        optimizer.zero_grad()

        output = net(torch.flatten(X))
        loss = nn.CrossEntropyLoss(output, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')
            running_loss = 0.0
print("Finished training.")

torch.save(net.state_dict(), './classifier.pth')

किसी कारण से, मुझे आउटपुट मिल रहा है

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

लाइन पर: output = net(torch.flatten(X)

आपकी मदद के लिए अग्रिम धन्यवाद!

1
Shiv Bhatia 21 जुलाई 2020, 15:25

1 उत्तर

सबसे बढ़िया उत्तर

जब flatten() आप सभी आयाम सहित बैच आयाम हटा देते हैं!

प्रयत्न:

output = net(x.view(x.shape[0], -1))
2
Shai 21 जुलाई 2020, 15:36