Q：例えば，MNISTをCNNモデルで学習するには，どのように変更を加えればよいですか？

A：まず，読み込むデータセットをMNISTに変更します．MNISTはグレースケール（1チャネル）画像なので，`Normalize` には要素数1の平均・標準偏差を渡します．

```python
# ... (other unchanged code from the original example is omitted here)

# Imports used by all snippets in this answer.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Load MNIST.
# MNIST is single-channel (grayscale), so Normalize takes one-element mean/std
# tuples; (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixel range to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])


trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)  # reshuffled every epoch


testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False)  # fixed order for evaluation

```

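次に，ニューラルネットワークをCNNに変更します．入力は1チャネルの28×28画像なので，5×5の畳み込みと2×2の最大プーリングを2回ずつ通すと特徴マップの一辺は 28→24→12→8→4 となり，全結合層 `fc1` の入力サイズは `16 * 4 * 4` になります．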

```python
# CNN definition
class Net(nn.Module):
    """Small two-stage CNN classifier for MNIST.

    Expects input of shape (N, 1, 28, 28) and returns raw class logits of
    shape (N, 10). No softmax is applied, so the output can be fed directly
    to nn.CrossEntropyLoss.
    """

    def __init__(self):
        super().__init__()
        # Spatial sizes for a 28x28 input:
        # 28 -conv(5x5)-> 24 -pool-> 12 -conv(5x5)-> 8 -pool-> 4
        self.conv1 = nn.Conv2d(1, 6, 5)        # 1 input channel, 6 output channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 max pooling, stride 2 (reused after both convs)
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6 input channels, 16 output channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16 feature maps of 4x4 after the second pool
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # 10 classes for MNIST

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch x of shape (N, 1, 28, 28)."""
        x = self.pool(F.relu(self.conv1(x)))  # Conv -> ReLU -> Max Pooling: (N, 6, 12, 12)
        x = self.pool(F.relu(self.conv2(x)))  # Conv -> ReLU -> Max Pooling: (N, 16, 4, 4)
        # Flatten all dims except the batch dim. Unlike view(-1, 16 * 4 * 4),
        # this never silently changes the batch size; a wrongly sized input
        # now fails at fc1 with a shape error instead of yielding bogus rows.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))  # FC -> ReLU
        x = F.relu(self.fc2(x))  # FC -> ReLU
        return self.fc3(x)       # FC -> raw logits (no softmax)
```

最後に，損失関数を分類問題向けの交差エントロピー損失（`nn.CrossEntropyLoss`）に変更し，学習ループを以下のように設定します．

```python
# Model, loss function, and optimizer
import torch  # for torch.no_grad() in the evaluation pass below

model = Net()
criterion = nn.CrossEntropyLoss()  # changed from MSE to cross entropy loss for classification (takes raw logits + integer labels)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # training mode; matters if dropout/batch-norm layers are added later
    running_loss = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average of the per-batch losses over this epoch.
    print(f'Epoch {epoch + 1} loss: {running_loss / len(trainloader):.3f}')

print('Finished Training')

# Evaluate on the held-out test set (testloader is otherwise never used).
model.eval()
correct = 0
total = 0
with torch.no_grad():  # no gradients needed for inference
    for inputs, labels in testloader:
        predicted = model(inputs).argmax(dim=1)  # index of the largest logit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {100 * correct / total:.2f}%')

```
