In [42]:
# Re-using the code from previous notebooks
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2),
            nn.Sigmoid(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.Sigmoid(),
            nn.MaxPool2d(2, stride=2),
            nn.Flatten(),
            nn.Linear(400,120),
            nn.Sigmoid(),
            nn.Linear(120,84),
            nn.Sigmoid(),
            nn.Linear(84,10),
        )

    def forward(self, x):
        logits = self.network(x)
        return logits



In [43]:
model = NeuralNetwork()
print(model)

sample = next(iter(train_dataloader))
print((sample[0].shape))
print((sample[1].shape))

NeuralNetwork(
  (network): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): Sigmoid()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Sigmoid()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=400, out_features=120, bias=True)
    (8): Sigmoid()
    (9): Linear(in_features=120, out_features=84, bias=True)
    (10): Sigmoid()
    (11): Linear(in_features=84, out_features=10, bias=True)
  )
)
torch.Size([64, 1, 28, 28])
torch.Size([64])


In [44]:
# set hyperparameters
learning_rate = 1e-2
batch_size = 64
epochs = 10

# Initialize the loss function
# In this case, we use CrossEntropyLoss for classification
# Regression problems would use MSELoss
loss_fn = nn.CrossEntropyLoss()

# Initialize the optimizer, here: Stochastic Gradient Descent
# other options: Adam, RMSprop, etc.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [None]:
# loops over our optimization code
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    print(f"current learning rate: {learning_rate}")
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    lr_scheduler.step()


In [45]:
# evaluate the model's performance against the test dataset
def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [46]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
current learning rate: 0.01
loss: 2.351394  [   64/60000]
loss: 1.285827  [ 6464/60000]
loss: 0.378701  [12864/60000]
loss: 0.259473  [19264/60000]
loss: 0.129985  [25664/60000]
loss: 0.155150  [32064/60000]
loss: 0.019095  [38464/60000]
loss: 0.161486  [44864/60000]
loss: 0.368852  [51264/60000]
loss: 0.180827  [57664/60000]
Test Error: 
 Accuracy: 97.1%, Avg loss: 0.098581 

Epoch 2
-------------------------------
current learning rate: 0.01
loss: 0.095380  [   64/60000]
loss: 0.107501  [ 6464/60000]
loss: 0.032343  [12864/60000]
loss: 0.064993  [19264/60000]
loss: 0.064923  [25664/60000]
loss: 0.082870  [32064/60000]
loss: 0.023708  [38464/60000]
loss: 0.062820  [44864/60000]
loss: 0.294651  [51264/60000]
loss: 0.089611  [57664/60000]
Test Error: 
 Accuracy: 98.2%, Avg loss: 0.055196 

Epoch 3
-------------------------------
current learning rate: 0.01
loss: 0.019176  [   64/60000]
loss: 0.034591  [ 6464/60000]
loss: 0.009431  [12864/60000]
lo

In [None]:
torch.save(model, 'model.pth')
model = torch.load('model.pth')

In [48]:
with torch.no_grad():
  output = model(sample[0])
  print(len(output))
  for i in range(64):
    print(output[i].argmax(), sample[1][i])

64
tensor(5) tensor(5)
tensor(0) tensor(0)
tensor(4) tensor(4)
tensor(1) tensor(1)
tensor(9) tensor(9)
tensor(2) tensor(2)
tensor(1) tensor(1)
tensor(3) tensor(3)
tensor(1) tensor(1)
tensor(4) tensor(4)
tensor(3) tensor(3)
tensor(5) tensor(5)
tensor(3) tensor(3)
tensor(6) tensor(6)
tensor(1) tensor(1)
tensor(7) tensor(7)
tensor(2) tensor(2)
tensor(8) tensor(8)
tensor(6) tensor(6)
tensor(9) tensor(9)
tensor(4) tensor(4)
tensor(0) tensor(0)
tensor(9) tensor(9)
tensor(1) tensor(1)
tensor(1) tensor(1)
tensor(2) tensor(2)
tensor(4) tensor(4)
tensor(3) tensor(3)
tensor(2) tensor(2)
tensor(7) tensor(7)
tensor(3) tensor(3)
tensor(8) tensor(8)
tensor(6) tensor(6)
tensor(9) tensor(9)
tensor(0) tensor(0)
tensor(5) tensor(5)
tensor(6) tensor(6)
tensor(0) tensor(0)
tensor(7) tensor(7)
tensor(6) tensor(6)
tensor(1) tensor(1)
tensor(8) tensor(8)
tensor(7) tensor(7)
tensor(9) tensor(9)
tensor(3) tensor(3)
tensor(9) tensor(9)
tensor(8) tensor(8)
tensor(5) tensor(5)
tensor(9) tensor(9)
tensor(3) tensor(