PyTorch-Ignite

How to load checkpoint and resume training

In this example, we will be using a ResNet18 model on the MNIST dataset. The base code is the same as used in the Getting Started Guide.

Required Dependencies

!pip install pytorch-ignite -q

Basic Setup

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Checkpoint, global_step_from_engine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.model = resnet18(num_classes=10)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.model(x)

model = Net().to(device)

data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=True),
    batch_size=128,  # batch size chosen for illustration
    shuffle=True,
)

val_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=False),
    batch_size=256,  # batch size chosen for illustration
    shuffle=False,
)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()
Downloading to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(
    model, metrics={"accuracy": Accuracy(), "loss": Loss(criterion)}, device=device
)

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(
        f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )


We can use Checkpoint() as shown below to save the latest model after each epoch completes. Because to_save also includes the optimizer and trainer, their state is saved alongside the model, so we can later load this checkpoint and resume training.

to_save = {'model': model, 'optimizer': optimizer, 'trainer': trainer}
checkpoint_dir = "checkpoints/"

checkpoint = Checkpoint(
    to_save,
    checkpoint_dir,
    n_saved=1,  # keep only the latest checkpoint
    global_step_transform=global_step_from_engine(trainer),  # suffix filenames with the trainer epoch
)
evaluator.add_event_handler(Events.COMPLETED, checkpoint)
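Every object placed in to_save is expected to follow PyTorch's state_dict() / load_state_dict() protocol; nn.Module, Optimizer, and Ignite's Engine all expose these two methods, which is what lets Checkpoint treat them uniformly. A pure-Python sketch of that protocol (a stand-in class, not actual Ignite code):

```python
# Pure-Python sketch of the state_dict protocol (TinyTrainer is a
# hypothetical stand-in for an Ignite Engine, not the real class).

class TinyTrainer:
    """Stand-in for a trainer: only tracks an epoch counter."""

    def __init__(self):
        self.epoch = 0

    def state_dict(self):
        # Return everything needed to restore this object later.
        return {"epoch": self.epoch}

    def load_state_dict(self, state):
        # Restore from a previously saved state_dict.
        self.epoch = state["epoch"]

trainer_stub = TinyTrainer()
trainer_stub.epoch = 2                          # pretend we trained 2 epochs
saved = {"trainer": trainer_stub.state_dict()}  # roughly what a checkpoint holds

restored = TinyTrainer()
restored.load_state_dict(saved["trainer"])
assert restored.epoch == 2
```

Because the protocol is uniform, adding another stateful object (e.g. an LR scheduler) to to_save requires no extra code.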

Start Training

Finally, we start the engine on the training dataset and run it for 2 epochs:

trainer.run(train_loader, max_epochs=2)
Validation Results - Epoch[1] Avg accuracy: 0.96 Avg loss: 0.16
Validation Results - Epoch[2] Avg accuracy: 0.98 Avg loss: 0.07

State:
	iteration: 938
	epoch: 2
	epoch_length: 469
	max_epochs: 2
	output: 0.026344267651438713
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class ''>
	seed: <class 'NoneType'>
	times: <class 'dict'>

Load Checkpoint

Now let's assume we have reset our model, optimizer, and trainer. After instantiating these objects again, we need to resume training from the checkpoint we saved.

!ls checkpoints

We can use load_objects() to apply the state of our checkpoint to the objects stored in to_save.

checkpoint_fp = checkpoint_dir + ""  # insert the checkpoint filename shown by `ls checkpoints` above
checkpoint = torch.load(checkpoint_fp, map_location=device)
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
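Conceptually, load_objects matches each key in the loaded checkpoint dict against to_load and restores the corresponding live object in place. A hedged, pure-Python sketch of that idea (not the actual Ignite implementation, which also handles nested and wrapped objects):

```python
# Simplified sketch of what Checkpoint.load_objects does conceptually:
# match names, then call load_state_dict on each live object.

def load_objects_sketch(to_load, checkpoint):
    for name, obj in to_load.items():
        if name not in checkpoint:
            raise ValueError(f"Object '{name}' not found in checkpoint")
        obj.load_state_dict(checkpoint[name])

class Stub:
    """Minimal object implementing the state_dict protocol."""
    def __init__(self):
        self.state = {}
    def state_dict(self):
        return dict(self.state)
    def load_state_dict(self, state):
        self.state = dict(state)

model_stub, optimizer_stub = Stub(), Stub()
ckpt = {"model": {"w": 1.0}, "optimizer": {"lr": 0.005}}
load_objects_sketch({"model": model_stub, "optimizer": optimizer_stub}, ckpt)
assert model_stub.state == {"w": 1.0}
assert optimizer_stub.state == {"lr": 0.005}
```

Note that the objects are mutated in place, which is why the model, optimizer, and trainer created earlier can be used directly after loading.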

Resume Training

trainer.run(train_loader, max_epochs=4)
Validation Results - Epoch[3] Avg accuracy: 0.99 Avg loss: 0.04
Validation Results - Epoch[4] Avg accuracy: 0.98 Avg loss: 0.06

State:
	iteration: 1876
	epoch: 4
	epoch_length: 469
	max_epochs: 4
	output: 0.0412273071706295
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class ''>
	seed: <class 'NoneType'>
	times: <class 'dict'>
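Note that the run only executes epochs 3 and 4: the restored trainer state already records epoch 2, so training continues from there up to max_epochs. A small sketch of that arithmetic:

```python
# Why run(train_loader, max_epochs=4) performs only 2 more epochs
# after resuming: the loop continues from the restored epoch count.
restored_epoch = 2  # recorded in the loaded checkpoint
max_epochs = 4
remaining = list(range(restored_epoch + 1, max_epochs + 1))
assert remaining == [3, 4]  # only epochs 3 and 4 are executed
```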