PyTorch-Ignite

State

State is introduced in Engine to store the output of the process_function, the current epoch and iteration, and other helpful information. Each Engine contains a State, which includes the following:

  • engine.state.seed: Seed to set at each data “epoch”.
  • engine.state.epoch: Number of epochs the engine has completed. Initialized as 0 and the first epoch is 1.
  • engine.state.iteration: Number of iterations the engine has completed. Initialized as 0 and the first iteration is 1.
  • engine.state.max_epochs: Number of epochs to run for. Initialized as 1.
  • engine.state.output: The output of the process_function defined for the Engine. See below.
  • etc.

Other attributes can be found in the docs of State.
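Because engine.state.iteration counts completed iterations for the whole run, it keeps growing across epochs instead of restarting at 1. The pure-Python sketch below replays the two counters for a run of 2 epochs over 3 batches. It does not import Ignite; `replay_counters`, `iteration_in_epoch` and `epoch_length` are hypothetical names chosen for illustration, assuming every epoch has the same number of iterations.

```python
def iteration_in_epoch(iteration: int, epoch_length: int) -> int:
    """Map a global, 1-based iteration counter to a 1-based index within its epoch."""
    return (iteration - 1) % epoch_length + 1


def replay_counters(max_epochs: int, epoch_length: int):
    """Replay how epoch and iteration advance (a mimic of the counters, not Ignite itself)."""
    epoch, iteration = 0, 0  # both start at 0 before the run
    history = []
    for _ in range(max_epochs):
        epoch += 1  # the first epoch is 1
        for _ in range(epoch_length):
            iteration += 1  # the first iteration is 1; never reset between epochs
            history.append((epoch, iteration, iteration_in_epoch(iteration, epoch_length)))
    return history


for epoch, iteration, local in replay_counters(max_epochs=2, epoch_length=3):
    print(f"epoch={epoch} iteration={iteration} iteration_in_epoch={local}")
```

In the second epoch the global counter reads 4, 5, 6 while the derived position within the epoch reads 1, 2, 3.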

In the code below, engine.state.output will store the batch loss. This output is used to print the loss at every iteration.

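from ignite.engine import Engine, Events

# model, loss_fn and optimizer are assumed to be defined elsewhere.

def update(engine, batch):
    model.train()
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The returned value is stored in engine.state.output.
    return loss.item()

trainer = Engine(update)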

def on_iteration_completed(engine):
    iteration = engine.state.iteration
    epoch = engine.state.epoch
    loss = engine.state.output
    print(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")

trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
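The handler above only reads engine.state.output; it is the engine that fills it in on every iteration. The pure-Python sketch below mimics that flow without Ignite. `MiniEngine`, `fake_update` and the `ITERATION_COMPLETED` string are hypothetical stand-ins, and the real Engine fires many more events than this single one.

```python
from collections import defaultdict
from types import SimpleNamespace

ITERATION_COMPLETED = "iteration_completed"  # hypothetical event name, not ignite.engine.Events


class MiniEngine:
    """Hypothetical, stripped-down engine that mimics how state.output gets filled in."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.state = SimpleNamespace(epoch=0, iteration=0, output=None)
        self._handlers = defaultdict(list)

    def add_event_handler(self, event, handler):
        self._handlers[event].append(handler)

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            self.state.epoch += 1
            for batch in data:
                self.state.iteration += 1
                # The process function's return value becomes state.output ...
                self.state.output = self.process_function(self, batch)
                # ... which every handler fired for this iteration can then read.
                for handler in self._handlers[ITERATION_COMPLETED]:
                    handler(self)
        return self.state


def fake_update(engine, batch):
    # Hypothetical training step: returns the batch mean as a stand-in "loss".
    return sum(batch) / len(batch)


logged = []


def on_iteration_completed(engine):
    state = engine.state
    logged.append((state.epoch, state.iteration, state.output))
    print(f"Epoch: {state.epoch}, Iteration: {state.iteration}, Loss: {state.output}")


trainer = MiniEngine(fake_update)
trainer.add_event_handler(ITERATION_COMPLETED, on_iteration_completed)
trainer.run([[1, 3], [2, 4]], max_epochs=2)
```

The handler sees a fresh output on every iteration; once run() returns, state.output holds only the value from the last batch.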

Since there are no restrictions on the output of process_function, Ignite provides an output_transform argument for its ignite.metrics and ignite.handlers. The output_transform argument is a function used to transform engine.state.output for the intended use. Below we’ll see different types of engine.state.output and how to transform them.

In the code below, engine.state.output will be a tuple of loss, y_pred and y for the processed batch. If we want to attach Accuracy to the engine, output_transform is needed to extract y_pred and y from engine.state.output. Let’s see how that is done:

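from ignite.engine import Engine, Events
from ignite.metrics import Accuracy

# model, loss_fn, optimizer and data are assumed to be defined elsewhere.

def update(engine, batch):
    model.train()
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), y_pred, y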

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output[0]
    print(f'Epoch {epoch}: train_loss = {loss}')

accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
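Accuracy itself works on tensors, but the role of output_transform can be shown with plain Python. In the sketch below, `SimpleAccuracy` is a hypothetical class (not ignite.metrics.Accuracy) that applies its output_transform to each raw output before updating its counts; the outputs are made-up (loss, y_pred, y) tuples with predicted class labels in place of model scores.

```python
from typing import Any, Callable, Sequence, Tuple


class SimpleAccuracy:
    """Hypothetical metric mimicking how output_transform is applied; not ignite.metrics.Accuracy."""

    def __init__(self, output_transform: Callable[[Any], Tuple[Sequence[int], Sequence[int]]] = lambda x: x):
        self.output_transform = output_transform
        self.correct = 0
        self.total = 0

    def update(self, engine_output: Any) -> None:
        # Reduce the raw engine.state.output to the (y_pred, y) pair this metric needs.
        y_pred, y = self.output_transform(engine_output)
        self.correct += sum(int(p == t) for p, t in zip(y_pred, y))
        self.total += len(y)

    def compute(self) -> float:
        return self.correct / self.total


# Each item plays the role of engine.state.output after one iteration.
outputs = [
    (0.9, [1, 0, 2], [1, 0, 1]),
    (0.4, [2, 2, 0], [2, 2, 0]),
]

acc = SimpleAccuracy(output_transform=lambda x: (x[1], x[2]))
for out in outputs:
    acc.update(out)
print(acc.compute())
```

The loss at index 0 never reaches the metric: the transform drops it, so the metric only ever sees 5 correct predictions out of 6.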

Similar to the above, but this time the output of the process_function is a dictionary of loss, y_pred and y for the processed batch. This is how the user can use output_transform to get y_pred and y from engine.state.output:

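from ignite.engine import Engine, Events
from ignite.metrics import Accuracy

# model, loss_fn, optimizer and data are assumed to be defined elsewhere.

def update(engine, batch):
    model.train()
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {'loss': loss.item(),
            'y_pred': y_pred,
            'y': y}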

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output['loss']
    print(f'Epoch {epoch}: train_loss = {loss}')

accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
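With a dictionary output, the transform selects entries by key rather than by position. As a small standard-library illustration, operator.itemgetter('y_pred', 'y') returns the same (y_pred, y) tuple as a lambda, so it is one possible alternative way to write such a one-argument transform. The sample dictionary below is made up for illustration.

```python
from operator import itemgetter

# Hypothetical engine.state.output for one iteration, in the dictionary form used above.
output = {"loss": 0.35, "y_pred": [1, 0, 2], "y": [1, 1, 2]}

by_lambda = lambda x: (x["y_pred"], x["y"])
by_itemgetter = itemgetter("y_pred", "y")  # returns a tuple when given several keys

assert by_lambda(output) == by_itemgetter(output)

# A key-based transform keeps working if the process function later
# returns extra entries, which positional indexing would not survive.
output["lr"] = 1e-3
y_pred, y = by_itemgetter(output)
print(y_pred, y)
```

Selecting by key is what makes the dictionary form more robust than the tuple form: adding or reordering entries in the returned dictionary does not change what the metric receives.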

Note:

A good practice is to also use State as storage for user data created in update or handler functions. For example, to save new_attribute in the state:

def user_handler_function(engine):
    engine.state.new_attribute = 12345
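A typical use is accumulating values across iterations. The sketch below mimics an engine and its state with types.SimpleNamespace so it runs without Ignite; the handler names and the `losses` and `mean_loss` attributes are hypothetical, and the plain loop stands in for the engine firing its start, per-iteration and completion events.

```python
from types import SimpleNamespace

# Stand-in for an engine and its state; only the attribute-storage pattern matters here.
engine = SimpleNamespace(state=SimpleNamespace(iteration=0, output=None))


def on_started(engine):
    # Create the user attribute once, before any iteration touches it.
    engine.state.losses = []


def on_iteration_completed(engine):
    # Store the batch loss that the (hypothetical) update function returned.
    engine.state.losses.append(engine.state.output)


def on_completed(engine):
    losses = engine.state.losses
    engine.state.mean_loss = sum(losses) / len(losses)


on_started(engine)
for batch_loss in [0.5, 0.25, 0.75]:  # pretend these values came from update()
    engine.state.iteration += 1
    engine.state.output = batch_loss
    on_iteration_completed(engine)
on_completed(engine)

print(engine.state.mean_loss)
```

Initializing the attribute in a start handler, rather than on first use, keeps every later handler free of existence checks.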