PyTorch-Ignite

Reinforcement Learning with Ignite

In this tutorial we will implement a policy-gradient algorithm called REINFORCE and use it to solve OpenAI’s CartPole problem with PyTorch-Ignite.


The reader should be familiar with the basic concepts of Reinforcement Learning like state, action, environment, etc.

The Cartpole Problem

We have to balance a pole that is attached by a joint to a cart. The cart moves along a frictionless track, and we keep the pole balanced by pushing the cart left or right in one dimension. Let’s start by defining a few terms.


State space

There are 4 variables that describe the state of the environment: cart position, cart velocity, pole angle and pole angular velocity.

Action space

There are 2 possible actions that the agent can perform: push the cart to the left or to the right.


Reward

For every time step in which the pole stays upright and the cart stays within range, the agent receives a reward of +1.
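The reward and end-of-episode rules can be sketched in plain Python. This is a minimal sketch, not Gymnasium’s code: it assumes the thresholds used by Gymnasium’s CartPole implementation (the episode terminates when the cart position leaves ±2.4 or the pole tilts more than 12 degrees) and the 200-step time limit registered for CartPole-v0, which Gymnasium applies through its TimeLimit wrapper and reports as truncated. The helper name cartpole_reward_and_flags is hypothetical.

```python
import math

# Assumed constants, mirroring Gymnasium's CartPole-v0 settings
X_THRESHOLD = 2.4                          # cart position limit
THETA_THRESHOLD = 12 * 2 * math.pi / 360   # 12 degrees, in radians
MAX_EPISODE_STEPS = 200                    # time limit for CartPole-v0


def cartpole_reward_and_flags(state, step_count):
    """Hypothetical helper: reward and end-of-episode flags after one step.

    state is (cart position, cart velocity, pole angle, pole angular velocity);
    step_count is the number of steps taken so far in the episode.
    """
    x, _x_dot, theta, _theta_dot = state
    terminated = abs(x) > X_THRESHOLD or abs(theta) > THETA_THRESHOLD
    truncated = step_count >= MAX_EPISODE_STEPS
    return 1.0, terminated, truncated  # +1 for every step taken


print(cartpole_reward_and_flags((0.0, 0.1, 0.05, -0.2), step_count=1))
```

Because every step earns exactly +1, the total reward of an episode equals its length in steps.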

When is it solved?

The problem is considered solved when the running average reward exceeds the reward_threshold defined for the environment (195.0 for CartPole-v0).
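In the training code, the "average reward" is an exponential moving average of episode rewards. The plain-Python sketch below reproduces that rule, assuming the same smoothing factor of 0.05 and initial value of 10 used later in the tutorial, and CartPole-v0’s threshold of 195.0; the helper names are hypothetical.

```python
ALPHA = 0.05                  # smoothing factor used in the training loop
THRESHOLD = 195.0             # reward_threshold of CartPole-v0
INITIAL_RUNNING_REWARD = 10.0


def update_running_reward(running_reward, ep_reward, alpha=ALPHA):
    """Exponential moving average of episode rewards."""
    return alpha * ep_reward + (1 - alpha) * running_reward


def episodes_until_solved(episode_rewards, threshold=THRESHOLD):
    """Hypothetical helper: 1-based episode index at which the problem counts as solved."""
    running_reward = INITIAL_RUNNING_REWARD
    for episode, ep_reward in enumerate(episode_rewards, start=1):
        running_reward = update_running_reward(running_reward, ep_reward)
        if running_reward > threshold:
            return episode
    return None  # threshold never crossed


print(episodes_until_solved([200.0] * 100))   # 71
print(episodes_until_solved([100.0] * 1000))  # None
```

Starting from 10, even a streak of perfect 200-step episodes needs 71 episodes before the average passes 195, so the criterion rewards sustained performance rather than a single lucky episode.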

Required Dependencies

!pip install gymnasium pytorch-ignite

On Colab

We need additional dependencies to render the environment on Google Colab.

!apt-get install -y xvfb python-opengl
!pip install pyvirtualdisplay
!pip install --upgrade pygame moviepy


from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

from ignite.engine import Engine, Events

import gymnasium as gym
from gymnasium.wrappers import RecordVideo

import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

Configurable Parameters

We will use these values later in the tutorial at appropriate places.

seed_val = 543
gamma = 0.99
log_interval = 100
max_episodes = 1000000
render = True
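To get a feel for gamma = 0.99, the snippet below computes how much a reward received k steps in the future is worth relative to an immediate one, together with the common rule-of-thumb "effective horizon" 1 / (1 - gamma). It is only an illustration; nothing later in the tutorial depends on it.

```python
gamma = 0.99  # same value as in the configurable parameters

# Weight of a reward received k steps in the future, relative to an immediate reward
weights = {k: round(gamma ** k, 3) for k in (1, 10, 100, 200)}
print(weights)  # {1: 0.99, 10: 0.904, 100: 0.366, 200: 0.134}

# Rule-of-thumb horizon of the discounted return
effective_horizon = 1 / (1 - gamma)
print(round(effective_horizon))  # 100
```

With a 200-step episode limit, rewards near the end of an episode still carry noticeable weight for actions taken at the start.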

Setting up the Environment

Let’s load our environment first.

env = gym.make("CartPole-v0", render_mode="rgb_array")

On Colab

On Google Colab, a few extra steps are needed to render the output. First, we start a virtual display with the given screen size.

display = Display(visible=0, size=(1400, 900))
display.start()

Below is a utility function that enables video recording of the gym environment: it wraps the environment in Gymnasium’s RecordVideo wrapper, which saves the recorded episodes to ./video.

def wrap_env(env):
  env = RecordVideo(env, './video', disable_logger=True)
  return env

env = wrap_env(env)


We are going to use the REINFORCE algorithm, in which the agent learns from complete episodes sampled directly from the environment, from the initial state until the episode ends. Our model has two linear layers: the first maps the 4 state variables to 128 hidden features, and the second maps those hidden features to one score for each of the 2 actions. In between, the hidden features pass through dropout and a ReLU activation. The model also holds two buffers, saved_log_probs and rewards, which store the log-probabilities of the chosen actions and the rewards received during the current episode. Finally, the scores are passed through a softmax, so forward returns a probability for each action.

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)
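The last line of forward turns the two action scores into probabilities. For one row of the batch, F.softmax(..., dim=1) computes exp(s_i) / sum_j exp(s_j); the plain-Python sketch below shows the same computation for a single state, assuming made-up scores of 1.0 (left) and 2.0 (right).

```python
import math


def softmax(scores):
    """Convert raw action scores into probabilities that sum to 1."""
    m = max(scores)                        # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0])  # hypothetical scores for "left" and "right"
print(probs)  # roughly [0.269, 0.731]
```

Subtracting the maximum score does not change the result but prevents overflow when scores are large.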

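Next, we initialize the model, the optimizer, eps (the machine epsilon of float32, added to the denominator to avoid division by zero when normalizing returns) and timesteps.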

timesteps is the data that the trainer iterates over: each Engine epoch corresponds to one episode and each iteration to one time step, so an episode can last at most 10,000 steps. In practice it ends much earlier, as soon as the environment reports that the episode is over.

policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()
timesteps = range(10000)

Create Trainer

Ignite’s Engine calls a process_function once per iteration; here, each iteration is a single time step of an episode. In it, we select an action from the policy, apply it with env.step(), store the reward and add it to the episode reward. When the episode ends, we save the timestep and terminate the current epoch, which ends the episode.

An episode is one run of the game, from the moment the environment is reset until the game ends (for CartPole, when the pole falls, the cart leaves the track or the time limit is reached). A step, on the other hand, is a discrete time index that increases by one every time the agent acts, until the episode ends.

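def run_single_timestep(engine, timestep):
    observation = engine.state.observation
    action = select_action(policy, observation)
    # Gymnasium's step() returns (observation, reward, terminated, truncated, info)
    engine.state.observation, reward, terminated, truncated, _ = env.step(action)
    if render:
        env.render()

    policy.rewards.append(reward)
    engine.state.ep_reward += reward

    # The episode ends when the pole falls (terminated) or the time limit is hit (truncated)
    if terminated or truncated:
        engine.state.timestep = timestep
        engine.terminate_epoch()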

trainer = Engine(run_single_timestep)
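Each call of the process function handles exactly one environment step, and an episode ends on either of the two flags that Gymnasium’s step() returns: terminated (the pole fell or the cart left the track) or truncated (the time limit was reached). The sketch below shows that episode loop without Ignite or Gymnasium; ToyEnv and run_episode are hypothetical stand-ins, and ToyEnv’s pole never falls, so only the time limit ends its episodes.

```python
class ToyEnv:
    """Hypothetical stand-in mimicking Gymnasium's reset()/step() return values."""

    def __init__(self, max_steps=5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return [0.0, 0.0, 0.0, 0.0], {}

    def step(self, action):
        self.t += 1
        terminated = False                    # the toy pole never falls
        truncated = self.t >= self.max_steps  # time limit, like Gymnasium's TimeLimit
        return [0.0] * 4, 1.0, terminated, truncated, {}


def run_episode(env, choose_action, max_iterations=10000):
    """Return (last timestep index, episode reward); one loop pass is one step."""
    observation, _ = env.reset()
    ep_reward = 0.0
    for timestep in range(max_iterations):
        action = choose_action(observation)
        observation, reward, terminated, truncated, _ = env.step(action)
        ep_reward += reward
        # Checking only `terminated` would ignore the time limit entirely
        if terminated or truncated:
            return timestep, ep_reward
    return max_iterations - 1, ep_reward


print(run_episode(ToyEnv(max_steps=200), lambda obs: 0))  # (199, 200.0)
```

Timesteps are counted from 0, so an episode that lasts the full 200 steps ends at timestep 199.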

Next, we need to select an action. Given the list of action probabilities, we create a categorical distribution over them and sample an action from it. The log-probability of the sampled action is saved to the saved_log_probs buffer, and the action to take (left or right) is returned.

def select_action(policy, observation):
    state = torch.from_numpy(observation).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
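    action = m.sample()
    # Store log pi(a|s) for the policy-gradient update at the end of the episode
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()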
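Categorical(probs).sample() draws action i with probability probs[i], and m.log_prob(action) returns log(probs[action]). The dependency-free sketch below mimics both calls with Python’s random module to make that explicit; sample_action is a hypothetical name, and the training code should keep using torch.distributions.Categorical, whose log-probability stays connected to the autograd graph so that gradients can flow back into the policy.

```python
import math
import random


def sample_action(probs, rng=random):
    """Hypothetical helper: sample an action index and return (action, log-probability)."""
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, math.log(probs[action])


rng = random.Random(543)  # seeded for reproducibility
samples = [sample_action([0.25, 0.75], rng)[0] for _ in range(10_000)]
print(sum(samples) / len(samples))  # close to 0.75, the probability of action 1
```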

We initialize a list for the policy loss and compute the discounted return of every step from the rewards returned by the environment, then normalize the returns. Next, we compute the loss of each step as -log_prob * R, where R is the normalized return (playing the role of the advantage). Finally, we reset the gradients, backpropagate the summed policy loss, update the parameters, and clear the rewards and log-probability buffers.

def finish_episode(policy, optimizer, gamma):
    R = 0
    policy_loss = []
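    returns = deque()
    # Walk the episode backwards to accumulate discounted returns
    for r in policy.rewards[::-1]:
        R = r + gamma * R
        returns.appendleft(R)
    returns = torch.tensor(returns)
    # Normalize returns to stabilize training
    returns = (returns - returns.mean()) / (returns.std() + eps)

    for log_prob, R in zip(policy.saved_log_probs, returns):
        policy_loss.append(-log_prob * R)

    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()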

    del policy.rewards[:]
    del policy.saved_log_probs[:]
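The return computation and loss can be checked on a tiny episode without PyTorch. The sketch below repeats the same steps in plain Python for three rewards of 1 and gamma = 0.99: discounted returns accumulated backwards, normalization with the sample standard deviation (the default of torch.Tensor.std) plus a small eps, and the summed loss -log_prob * R. The helper names are hypothetical.

```python
import math
import statistics


def discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, computed from the last step backwards."""
    R = 0.0
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    returns.reverse()
    return returns


def normalize(returns, eps=1e-7):
    """Zero-mean, unit-variance returns; needs at least two values."""
    mean = statistics.mean(returns)
    std = statistics.stdev(returns)  # sample standard deviation
    return [(g - mean) / (std + eps) for g in returns]


def reinforce_loss(log_probs, returns):
    """Sum of -log_prob * R over the steps of one episode."""
    return sum(-lp * g for lp, g in zip(log_probs, returns))


returns = discounted_returns([1.0, 1.0, 1.0], gamma=0.99)
print(returns)  # approximately [2.9701, 1.99, 1.0]
normalized = normalize(returns)
loss = reinforce_loss([math.log(0.5)] * 3, normalized)
```

After normalization, steps whose return is below the episode mean receive a negative weight, so minimizing the loss makes their actions less likely.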

Attach handlers to run on specific events

We rename the start and end epoch events so that they read as episode events.

EPISODE_STARTED = Events.EPOCH_STARTED
EPISODE_COMPLETED = Events.EPOCH_COMPLETED

Before training begins, we initialize the running reward in the trainer’s state.

trainer.state.running_reward = 10

When an episode begins, we have to reset the environment’s state.

@trainer.on(EPISODE_STARTED)
def reset_environment_state():
    torch.manual_seed(seed_val + trainer.state.epoch)
    trainer.state.observation, _ = env.reset(seed=seed_val + trainer.state.epoch)
    trainer.state.ep_reward = 0

When an episode finishes, we update the running reward and perform backpropagation by calling finish_episode().

@trainer.on(EPISODE_COMPLETED)
def update_model():
    trainer.state.running_reward = 0.05 * trainer.state.ep_reward + (1 - 0.05) * trainer.state.running_reward
    finish_episode(policy, optimizer, gamma)

After that, every 100 (log_interval) episodes, we log the results.

@trainer.on(EPISODE_COMPLETED(every=log_interval))
def log_episode():
    i_episode = trainer.state.epoch
    print(
        f"Episode {i_episode}\tLast reward: {trainer.state.ep_reward:.2f}"
        f"\tAverage length: {trainer.state.running_reward:.2f}"
    )

And finally, we check if our running reward has crossed the threshold so that we can stop training.

@trainer.on(EPISODE_COMPLETED)
def should_finish_training():
    running_reward = trainer.state.running_reward
    if running_reward > env.spec.reward_threshold:
        print(
            f"Solved! Running reward is now {running_reward} and "
            f"the last episode runs to {trainer.state.timestep} time steps!"
        )
        trainer.terminate()
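The way this tutorial maps Ignite’s Engine onto episodes (one epoch per episode, one iteration per time step, terminate_epoch() to end an episode early, terminate() to stop the whole run, and handlers filtered with every=N) can be illustrated with a toy engine. MiniEngine below is hypothetical and heavily simplified; it is not Ignite’s implementation and only mimics the behavior the handlers rely on. Event names are plain strings here, and every is counted in epochs.

```python
from collections import defaultdict


class MiniEngine:
    """Hypothetical toy engine; a simplification, not Ignite's implementation."""

    def __init__(self, process_fn):
        self.process_fn = process_fn
        self.handlers = defaultdict(list)  # event name -> list of (every, handler)
        self.epoch = 0
        self.iteration = 0
        self._end_epoch = False
        self._end_run = False

    def on(self, event, every=1):
        """Decorator that registers a handler, optionally only every N epochs."""
        def register(handler):
            self.handlers[event].append((every, handler))
            return handler
        return register

    def _fire(self, event):
        for every, handler in self.handlers[event]:
            if self.epoch % every == 0:
                handler(self)

    def terminate_epoch(self):
        self._end_epoch = True  # leave the current epoch (episode) early

    def terminate(self):
        self._end_run = True    # stop the whole run

    def run(self, data, max_epochs):
        while self.epoch < max_epochs and not self._end_run:
            self.epoch += 1
            self._end_epoch = False
            self._fire("EPOCH_STARTED")
            for item in data:  # one item per iteration (time step)
                self.iteration += 1
                self.process_fn(self, item)
                if self._end_epoch or self._end_run:
                    break
            self._fire("EPOCH_COMPLETED")
        return self


def process(engine, timestep):
    if timestep == 2:  # pretend every episode ends after 3 steps
        engine.terminate_epoch()


engine = MiniEngine(process)
logged = []


@engine.on("EPOCH_COMPLETED", every=2)
def log_episode(e):
    logged.append(e.epoch)


@engine.on("EPOCH_COMPLETED")
def should_finish(e):
    if e.epoch == 4:  # stand-in for "running reward > threshold"
        e.terminate()


engine.run(range(10000), max_epochs=1000)
print(engine.epoch, engine.iteration, logged)  # 4 12 [2, 4]
```

Even though the data has 10,000 items and max_epochs is 1000, the run stops after 4 short epochs because the handlers end each epoch early and finally the whole run, which is the same pattern the CartPole trainer uses.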

Run Trainer

trainer.run(timesteps, max_epochs=max_episodes)
Episode 100	Last length:    66	Average length: 37.90
Episode 200	Last length:    21	Average length: 115.82
Episode 300	Last length:   199	Average length: 133.13
Episode 400	Last length:    98	Average length: 134.97
Episode 500	Last length:    77	Average length: 77.39
Episode 600	Last length:   199	Average length: 132.99
Episode 700	Last length:   122	Average length: 137.40
Episode 800	Last length:    39	Average length: 159.51
Episode 900	Last length:    86	Average length: 113.31
Episode 1000	Last length:    76	Average length: 114.67
Episode 1100	Last length:    96	Average length: 98.65
Episode 1200	Last length:    90	Average length: 84.50
Episode 1300	Last length:   102	Average length: 89.10
Episode 1400	Last length:    64	Average length: 86.45
Episode 1500	Last length:    60	Average length: 76.35
Episode 1600	Last length:    75	Average length: 71.38
Episode 1700	Last length:   176	Average length: 117.25
Episode 1800	Last length:   139	Average length: 140.96
Episode 1900	Last length:    63	Average length: 141.79
Episode 2000	Last length:    66	Average length: 94.01
Episode 2100	Last length:   199	Average length: 115.46
Episode 2200	Last length:   113	Average length: 137.11
Episode 2300	Last length:   174	Average length: 135.36
Episode 2400	Last length:    80	Average length: 116.46
Episode 2500	Last length:    96	Average length: 101.47
Episode 2600	Last length:   199	Average length: 141.13
Episode 2700	Last length:    13	Average length: 134.91
Episode 2800	Last length:    90	Average length: 71.22
Episode 2900	Last length:    61	Average length: 70.14
Episode 3000	Last length:   199	Average length: 129.67
Episode 3100	Last length:   199	Average length: 173.62
Episode 3200	Last length:   199	Average length: 189.30
Solved! Running reward is now 195.03268327777783 and the last episode runs to 199 time steps!

	iteration: 396569
	epoch: 3289
	epoch_length: 10000
	max_epochs: 1000000
	output: <class 'NoneType'>
	batch: 199
	metrics: <class 'dict'>
	dataloader: <class 'list'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	running_reward: 195.03268327777783
	observation: <class 'numpy.ndarray'>
	timestep: 199

On Colab

Finally, we can view our saved video.

mp4list = glob.glob('video/*.mp4')

if len(mp4list) > 0:
    mp4 = mp4list[-1]  # pick the last video
    with io.open(mp4, 'rb') as f:
        video = f.read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
else:
    print("Could not find video")

That’s it! We have successfully solved the Cartpole problem!