In [None]:
!pip install lightning

In [None]:
# Re-using the code from previous notebooks
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# new imports
import torch.nn.functional as F
import lightning as L

# creating MNIST datasets and dataloaders as before
# The files are downloaded into ./data on first use; ToTensor() converts
# every image to a tensor when it is read.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Number of samples per mini-batch, shared by both loaders.
BATCH_SIZE = 64

# Training batches are reshuffled every epoch so SGD does not see the samples
# in the same fixed order each time; evaluation order is left unchanged.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [None]:
# New: our lightning module

class LitNeuralNetwork(L.LightningModule):
    """Fully connected classifier packaged as a LightningModule.

    Inputs are flattened and passed through two hidden ``Linear`` + ``ReLU``
    layers followed by a linear output layer that returns one raw logit per
    class. The default arguments reproduce the 784-128-128-10 network.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of both hidden layers.
        num_classes: Number of logits produced per sample.
        learning_rate: Learning rate handed to the SGD optimizer.
    """

    def __init__(self, input_size=28*28, hidden_size=128, num_classes=10,
                 learning_rate=0.001):
        super().__init__()
        # Kept on the module so configure_optimizers can read it later.
        self.learning_rate = learning_rate
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Return the raw (un-normalised) class logits for a batch ``x``."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

    def configure_optimizers(self):
        """Create the plain SGD optimizer over all model parameters."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch.

        Args:
            batch: An ``(inputs, targets)`` pair yielded by the dataloader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss tensor for this batch.
        """
        X, y = batch
        logits = self(X)
        loss = F.cross_entropy(logits, y)
        # on_epoch=True asks for an epoch-level aggregate of the logged value.
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log loss and accuracy of one validation batch.

        Args:
            batch: An ``(inputs, targets)`` pair yielded by the dataloader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The cross-entropy loss tensor for this batch.
        """
        X, y = batch
        logits = self(X)
        loss = F.cross_entropy(logits, y)
        # Fraction of samples whose highest-scoring class equals the target.
        accuracy = (logits.argmax(1) == y).float().mean()
        self.log('val_loss', loss)
        self.log('Accuracy', accuracy)
        return loss


In [None]:
# for visualization during training: tensorboard
%load_ext tensorboard
%tensorboard --logdir .

In [None]:
# instantiating model and trainer, and training the model

# Seed the random number generators before the model is built so that the
# weight initialisation (and therefore the run) is repeatable.
L.seed_everything(42, workers=True)

callbacks = None  # no additional callbacks are registered
model = LitNeuralNetwork()
trainer = L.Trainer(max_epochs=10,
                    callbacks=callbacks,
                    accelerator="cpu", devices=1)
# The MNIST test split is used as the validation set during fitting.
trainer.fit(model,
            train_dataloaders=train_dataloader,
            val_dataloaders=test_dataloader)