DQN

The goal of this exercise is to implement DQN and to apply it to the cartpole balancing problem.

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    !pip install -U gymnasium pygame moviepy
    !pip install gymnasium[box2d]
import numpy as np
rng = np.random.default_rng()
import matplotlib.pyplot as plt
import os
from IPython.display import clear_output
from collections import deque

import gymnasium as gym
print("gym version:", gym.__version__)

import pygame
from moviepy.editor import ImageSequenceClip, ipython_display

import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)

class GymRecorder(object):
    """
    Simple wrapper over moviepy to generate a .gif with the frames of a gym environment.
    
    The environment must have the render_mode `rgb_array_list`.
    """
    def __init__(self, env):
        self.env = env
        self._frames = []

    def record(self, frames):
        "To be called at the end of an episode."
        for frame in frames:
            self._frames.append(np.array(frame))

    def make_video(self, filename):
        "Generates the gif video."
        directory = os.path.dirname(os.path.abspath(filename))
        os.makedirs(directory, exist_ok=True)
        self.clip = ImageSequenceClip(list(self._frames), fps=self.env.metadata["render_fps"])
        self.clip.write_gif(filename, fps=self.env.metadata["render_fps"], loop=0)
        del self._frames
        self._frames = []

def running_average(x, N):
    kernel = np.ones(N) / N
    return np.convolve(x, kernel, mode='same')
gym version: 0.26.3

Cartpole balancing task

We are going to use the Cartpole balancing problem, which can be loaded with:

gym.make('CartPole-v0')

States have 4 continuous values (position and speed of the cart, angle and angular speed of the pole), and there are 2 discrete actions (pushing the cart left or right). The reward is +1 for each transition where the pole is still standing, i.e. its angle with the vertical stays below 12°. The episode also ends if the cart moves more than 2.4 units away from the center.

In CartPole-v0, the episode ends when the pole fails or after 200 steps. In CartPole-v1, the maximum episode length is 500 steps, which is too long for us, so we stick to v0 here.

The maximal (undiscounted) return is therefore 200. Can DQN learn this?

# Create the environment
env = gym.make('CartPole-v0', render_mode="rgb_array_list")
recorder = GymRecorder(env)

# Sample the initial state
state, info = env.reset()

# One episode:
done = False
return_episode = 0
while not done:

    # Select an action randomly
    action = env.action_space.sample()
    
    # Sample a single transition
    next_state, reward, terminal, truncated, info = env.step(action)

    # End of the episode
    done = terminal or truncated

    # Update undiscounted return
    return_episode += reward
    
    # Go in the next state
    state = next_state

print("Return:", return_episode)

recorder.record(env.render())
video = "videos/cartpole.gif"
recorder.make_video(video)
ipython_display(video)
Return: 19.0
MoviePy - Building file videos/cartpole.gif with imageio.

As the problem is quite simple (4 state variables, 2 actions), DQN can run on a single CPU. However, training takes quite a long time. We therefore advise running the notebook on a GPU in Colab, so that you do not drain your laptop's battery too fast or make it too warm.

From now on, we will stop displaying the cartpole on Colab, as we want to go fast.

Creating the model

The first step is to create the value network using keras. We will not need anything fancy. A simple fully connected network will do the trick: 4 input neurons, two hidden layers of 64 neurons each and 2 output neurons, with ReLU activation functions in the hidden layers and the Adam optimizer.

Q: Which loss function should we use? Think about which arguments have to be passed to model.compile() and which activation function is required in the output layer.

We will need to create two identical networks: the trained network and the target network. You should therefore create a method that returns a compiled model, so that it can be called twice. You should pass it the environment (so the network knows how many input and output neurons it needs) and the learning rate for the Adam optimizer.

def create_model(env, lr):
    
    model = tf.keras.models.Sequential()

    # ...

    return model

Q: Implement the method accordingly.

def create_model(env, lr):
    
    model = tf.keras.models.Sequential()
    
    model.add(tf.keras.layers.Input(env.observation_space.shape))
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    model.add(tf.keras.layers.Dense(env.action_space.n, activation='linear'))
    
    model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=lr))
    
    model.summary()

    return model

Let’s test this method by creating the trained and target networks.

Important: every time you call create_model, a new neural network will be instantiated but the previous ones will not be deleted. During this exercise, you may have to create hundreds of networks because of the incremental implementation of DQN: all networks will stay instantiated in the RAM, and your computer/colab tab will freeze after a while. Before creating new networks, delete all existing ones with:

tf.keras.backend.clear_session()

Q: Create the trained and target networks. The learning rate does not matter for now. Instantiate the Cartpole environment and print the output of both networks for the initial state (state, info = env.reset()). Are they the same?

Hint: model.predict(X, verbose=0) expects an array X of shape (N, 4), with N the number of examples. Here, we have only one example, so make sure to reshape state so it has the shape (1, 4) (otherwise tf will complain).
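As a small illustration of the hint, the following NumPy sketch adds the batch dimension to a single state. The state values are made up. With reshape((1, -1)), NumPy infers the 4 from the data.

```python
import numpy as np

# A single (made-up) CartPole state: 4 values, shape (4,)
state = np.array([0.0127, 0.0069, 0.0087, 0.0091])

# Add a batch dimension so the network sees one example: shape (1, 4)
batch = state.reshape((1, -1))

# Equivalent alternative using np.newaxis
batch_alt = state[np.newaxis, :]

print(batch.shape, batch_alt.shape)
```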

env = gym.make('CartPole-v0')

state, info = env.reset()
state = state.reshape((1, env.observation_space.shape[0]))
print("State:", state)

tf.keras.backend.clear_session()
trained_model = create_model(env, 0.001)
target_model = create_model(env, 0.001)

trained_prediction = trained_model.predict(state, verbose=0)[0]
target_prediction = target_model.predict(state, verbose=0)[0]

print('-'*10)
print("State:", state)
print("Prediction for the trained network:", trained_prediction)
print("Prediction for the target network:", target_prediction)
print('-'*10)
State: [[0.01271655 0.00693787 0.00868186 0.00911883]]
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 64)                320       
                                                                 
 dense_1 (Dense)             (None, 64)                4160      
                                                                 
 dense_2 (Dense)             (None, 2)                 130       
                                                                 
=================================================================
Total params: 4,610
Trainable params: 4,610
Non-trainable params: 0
_________________________________________________________________
None
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_3 (Dense)             (None, 64)                320       
                                                                 
 dense_4 (Dense)             (None, 64)                4160      
                                                                 
 dense_5 (Dense)             (None, 2)                 130       
                                                                 
=================================================================
Total params: 4,610
Trainable params: 4,610
Non-trainable params: 0
_________________________________________________________________
None
----------
State: [[0.01271655 0.00693787 0.00868186 0.00911883]]
Prediction for the trained network: [0.00114096 0.00075357]
Prediction for the target network: [-0.00087994  0.00059144]
----------

The target network has the same structure as the trained network, but not the same weights, as both are randomly initialized. We want the weights \theta' of the target network to be exactly the same as the weights \theta of the trained network. You can obtain the weights of a network with:

w = model.get_weights()

and set weights using:

model.set_weights(w)

Q: Transfer the weights of the trained model to the target model. Compare their predictions for the current state.

target_model.set_weights(trained_model.get_weights())

trained_prediction = trained_model.predict(state, verbose=0)[0]
target_prediction = target_model.predict(state, verbose=0)[0]

print('-'*10)
print("State:", state)
print("Prediction for the trained network:", trained_prediction)
print("Prediction for the target network:", target_prediction)
print('-'*10)
----------
State: [[0.01271655 0.00693787 0.00868186 0.00911883]]
Prediction for the trained network: [0.00114096 0.00075357]
Prediction for the target network: [0.00114096 0.00075357]
----------

Experience replay memory

The second thing that we need is the experience replay memory (or replay buffer). We need a container like a python list where we append (s, a, r, s’, done) transitions (as in Q-learning), but with a maximal capacity. Once the list already contains C transitions, new transitions should overwrite the oldest ones instead of being appended at the end.

This would not be very hard to write, but it would take a lot of time, and the risk of hard-to-notice bugs is high.

Here is a basic implementation of the replay buffer using double-ended queues (deque). A deque created with the maxlen argument has a maximum capacity: once it is full, appending a new element discards the oldest one at the other end. This is exactly what we need. This implementation uses one deque per element in (s, a, r, s’, done), but one could also append the whole transition to a single deque.
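For comparison, here is a sketch of the single-deque alternative. The class name SingleDequeReplayBuffer is hypothetical and is not used elsewhere in this notebook. To stay dependency-free, it returns Python lists rather than numpy arrays. The usage lines also show how a full deque drops its oldest element.

```python
import random
from collections import deque


class SingleDequeReplayBuffer:
    """Hypothetical variant of ReplayBuffer storing whole transitions in one deque."""

    def __init__(self, max_capacity):
        # Once max_capacity transitions are stored, each append drops the oldest one
        self.transitions = deque(maxlen=max_capacity)

    def append(self, state, action, reward, next_state, done):
        self.transitions.append((state, action, reward, next_state, done))

    def __len__(self):
        return len(self.transitions)

    def sample(self, batch_size):
        # Same warm-up rule as ReplayBuffer: wait for at least 2*batch_size transitions
        if len(self.transitions) < 2 * batch_size:
            return []
        indices = random.sample(range(len(self.transitions)), batch_size)
        batch = [self.transitions[i] for i in indices]
        # Transpose the list of (s, a, r, s', done) tuples into 5 lists
        return [list(column) for column in zip(*batch)]


demo_buffer = SingleDequeReplayBuffer(max_capacity=5)
for t in range(8):
    demo_buffer.append(t, t % 2, 1.0, t + 1, False)

print(len(demo_buffer))               # 5: the 3 oldest transitions were dropped
print(demo_buffer.transitions[0][0])  # 3: the oldest remaining state
```

Storing whole tuples keeps each transition together in one place. The per-element deques of ReplayBuffer make it slightly easier to build one numpy array per element.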

Q: Read the code of the ReplayBuffer and understand what it does.

class ReplayBuffer:
    "Basic implementation of the experience replay memory using separated deques."
    def __init__(self, max_capacity):
        self.max_capacity = max_capacity
        
        # deques for each element
        self.states = deque(maxlen=max_capacity)
        self.actions = deque(maxlen=max_capacity)
        self.rewards = deque(maxlen=max_capacity)
        self.next_states = deque(maxlen=max_capacity)
        self.dones = deque(maxlen=max_capacity)
        
    def append(self, state, action, reward, next_state, done):
        # Store data
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)
        
    def sample(self, batch_size):
        # Do not return samples if we do not have at least 2*batch_size transitions
        if len(self.states) < 2*batch_size: 
            return []
            
        # Randomly choose the indices of the samples.
        indices = sorted(np.random.choice(np.arange(len(self.states)), batch_size, replace=False))

        # Return the corresponding transitions as numpy arrays
        return [np.array([self.states[i] for i in indices]), 
                np.array([self.actions[i] for i in indices]), 
                np.array([self.rewards[i] for i in indices]), 
                np.array([self.next_states[i] for i in indices]), 
                np.array([self.dones[i] for i in indices])]

Q: Run a random agent on Cartpole (without rendering) for a few episodes and append each transition to a replay buffer with small capacity (e.g. 100). Sample a batch to check that everything makes sense.

env = gym.make('CartPole-v0')

buffer = ReplayBuffer(100)

for episode in range(10):
            
    # Reset
    state, info = env.reset()
    done = False
    
    # Sample the episode
    while not done:
        
        # Select an action randomly
        action = env.action_space.sample()
            
        # Perform the action
        next_state, reward, terminal, truncated, info = env.step(action)

        # End of the episode
        done = terminal or truncated
                
        # Store the transition
        buffer.append(state, action, reward, next_state, done)
                
        # Go in the next state
        state = next_state
    
# Sample a minibatch
batch = buffer.sample(10)
print("States:", batch[0])
print("Actions:", batch[1])
print("Rewards:", batch[2])
print("Next states:", batch[3])
print("Dones:", batch[4])
States: [[-7.1889226e-05 -1.9597625e-02  1.1202861e-02  1.0221607e-02]
 [-5.1598279e-03 -2.1528986e-01  1.7882740e-02  3.1551802e-01]
 [-7.8443229e-02 -8.0612069e-01  1.3144340e-01  1.3234164e+00]
 [-4.6816234e-02  2.0717475e-01  4.1258741e-02 -2.8127652e-01]
 [-4.2442955e-02 -1.8412507e-01  3.6115777e-02  3.2783771e-01]
 [-9.1748878e-02 -3.8812301e-01  1.3098954e-01  8.2237107e-01]
 [ 7.3674910e-02  6.0291737e-01 -7.6641589e-02 -8.1436384e-01]
 [ 8.5733257e-02  7.9900050e-01 -9.2928864e-02 -1.1301358e+00]
 [ 1.2985145e-01  6.0852367e-01 -1.5678419e-01 -9.4639248e-01]
 [ 1.5025000e-02  7.9516947e-01 -4.6248329e-03 -1.1131814e+00]]
Actions: [0 0 1 0 0 0 1 0 0 1]
Rewards: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Next states: [[-4.6384174e-04 -2.1487844e-01  1.1407293e-02  3.0641800e-01]
 [-9.4656255e-03 -4.1066191e-01  2.4193101e-02  6.1378646e-01]
 [-9.4565637e-02 -6.1288112e-01  1.5791173e-01  1.0745906e+00]
 [-4.2672738e-02  1.1489304e-02  3.5633210e-02  2.4128472e-02]
 [-4.6125453e-02 -3.7974206e-01  4.2672534e-02  6.3168758e-01]
 [-9.9511340e-02 -5.8477044e-01  1.4743695e-01  1.1532161e+00]
 [ 8.5733257e-02  7.9900050e-01 -9.2928864e-02 -1.1301358e+00]
 [ 1.0171327e-01  6.0521013e-01 -1.1553158e-01 -8.6798614e-01]
 [ 1.4202191e-01  4.1582093e-01 -1.7571205e-01 -7.0678967e-01]
 [ 3.0928390e-02  9.9035186e-01 -2.6888460e-02 -1.4073114e+00]]
Dones: [False False False False False False False False False False]

DQN agent

Here starts the fun part. There are a lot of things to do here, and you will only know whether it works once everything has been (correctly) implemented. So here is a lot of text to read carefully, and then you are on your own.

Reminder from the lecture:

  • Initialize value network Q_{\theta} and target network Q_{\theta'}.

  • Initialize experience replay memory \mathcal{D} of maximal size N.

  • for t \in [0, T_\text{total}]:

    • Select an action a_t based on Q_\theta(s_t, a), observe s_{t+1} and r_{t+1}.

    • Store (s_t, a_t, r_{t+1}, s_{t+1}) in the experience replay memory.

    • Every T_\text{train} steps:

      • Sample a minibatch \mathcal{D}_s randomly from \mathcal{D}.

      • For each transition (s_k, a_k, r_k, s'_k) in the minibatch:

        • Compute the target value t_k = r_k + \gamma \, \max_{a'} Q_{\theta'}(s'_k, a') using the target network.
      • Update the value network Q_{\theta} on \mathcal{D}_s to minimize:

      \mathcal{L}(\theta) = \mathbb{E}_{\mathcal{D}_s}[(t_k - Q_\theta(s_k, a_k))^2]

    • Every T_\text{target} steps:

      • Update target network: \theta' \leftarrow \theta.

Here is the skeleton of the DQNAgent class that you have to write:

class DQNAgent:
    
    def __init__(self, env, create_model, some_parameters):
        
        self.env = env
        
        # TODO: copy the parameters

        # TODO: Create the trained and target networks, copy the weights.

        # TODO: Create an instance of the replay memory
        
    def act(self, state):

        # TODO: Select an action using epsilon-greedy on the output of the trained model

        return action
    
    def update(self, batch):
        
        # TODO: train the model using the batch of transitions
        
        return loss # mse on the batch

    def train(self, nb_episodes):

        returns = []
        losses = []

        # TODO: Train the network for the given number of episodes

        return returns, losses

    def test(self):

        # TODO: one episode with epsilon temporarily set to 0

        return nb_steps # Should be 200 after learning

With this structure, it will be very simple to actually train the DQN on Cartpole:

# Create the environment
env = gym.make('CartPole-v0')

# Create the agent
agent = DQNAgent(env, create_model, other_parameters)

# Train the agent
returns, losses = agent.train(nb_episodes)

# Plot the returns
plt.figure(figsize=(10, 6))
plt.plot(returns)
plt.plot(running_average(returns, 10))
plt.xlabel("Episodes")
plt.ylabel("Returns")

# Plot the losses
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel("Episodes")
plt.ylabel("Training loss")

plt.show()

# Test the network
nb_steps = agent.test()
print("Number of steps:", nb_steps)

So you “just” have to fill the holes.

1 - __init__(): Initializing the agent

In this method, you should first copy the value of the parameters as attributes: learning rate, epsilon, gamma and so on.

Suggested values: gamma = 0.99, learning_rate = 0.001

The second thing to do is to create the trained and target networks (with the same weights) and save them as attributes (the other methods will use them). Do not forget to clear the keras session first, otherwise the RAM will be quickly filled.

The third thing is to create an instance of the experience replay memory (ERM). Use a buffer limit of 5000 transitions (it should be passed as a parameter).

Do not hesitate to add other attributes (e.g. counters) as you implement the other methods.

2 - act(): action selection

We will use a simple \epsilon-greedy method for the action selection, as in the previous exercises.

The only difference is that we have to use the trained model to get the greedy action, using trained_model.predict(X, verbose=0). This will return the Q-value of the two actions left and right. Use argmax() to return the greedy action (with probability 1 - \epsilon). env.action_space.sample() should be used for the exploration (do not use the Q-network in that case, it is slow!).

\epsilon follows an exponential schedule: it starts at 1.0 and is multiplied by (1 - 0.0005) after each action. It is always better to keep a little exploration rather than letting \epsilon decay to 0, so keep a minimal value of 0.05 for epsilon.
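Here is a minimal, framework-free sketch of the action selection and of the epsilon schedule. The function names are hypothetical, and q_values stands in for the output of trained_model.predict(). The last line estimates how many actions it takes for \epsilon to reach its floor with these values.

```python
import math
import random


def epsilon_greedy(q_values, epsilon):
    """Random action with probability epsilon, otherwise the greedy (argmax) action."""
    if random.random() < epsilon:
        # Exploration: no need to query the Q-network at all
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def decay_epsilon(epsilon, decay=0.0005, epsilon_min=0.05):
    """Exponential decay after each action, never going below epsilon_min.

    An epsilon already at or below the floor (e.g. set to 0 for a test episode)
    is returned unchanged.
    """
    if epsilon <= epsilon_min:
        return epsilon
    return max(epsilon_min, epsilon * (1.0 - decay))


# Number of actions before epsilon reaches the floor: smallest n with (1 - decay)^n <= 0.05
n_floor = math.ceil(math.log(0.05) / math.log(1.0 - 0.0005))
print(n_floor)  # close to 6,000 actions
```

Note the guard in decay_epsilon. Without it, max(0.05, 0.0) would silently push an epsilon of 0 back up to 0.05, so a "greedy" test episode would in fact keep exploring.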

Q: Once this has been implemented, run your very slow random agent for 100 episodes to check everything works correctly.

3 - train(): training loop

This method will be very similar to the Q-learning agent that you implemented previously. Do not hesitate to copy and paste.

Here are the parts of the DQN algorithm that should be implemented:

  • for t \in [0, T_\text{total}]:

    • Select an action a_t based on Q_\theta(s_t, a), observe s_{t+1} and r_{t+1}.

    • Store (s_t, a_t, r_{t+1}, s_{t+1}) in the experience replay memory.

    • Every T_\text{train} steps:

      • Sample a minibatch \mathcal{D}_s randomly from \mathcal{D}.

      • Update the trained network using \mathcal{D}_s.

    • Every T_\text{target} steps:

      • Update target network: \theta' \leftarrow \theta.

The main difference with Q-learning is that update() will be called only every T_train = 4 steps. The number of updates to the trained network will therefore be 4 times smaller than the number of steps made in the environment. Beware that you should not call update() if the ERM does not contain enough transitions yet. The provided ReplayBuffer.sample() returns an empty list as long as it holds fewer than 2 * batch_size transitions.

Updating the target network (copying the weights of the trained network) should happen every 100 steps. Pass these parameters to the constructor of the agent.
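To get a feel for the two periods, the following sketch simply counts how often update() and the target copy would be triggered. The helper count_updates is hypothetical and not part of the agent. It assumes one transition is stored per step and the 2 * batch_size warm-up of ReplayBuffer.sample().

```python
def count_updates(total_steps, train_period=4, target_period=100, batch_size=32, warmup_factor=2):
    """Count gradient updates and target-network copies over total_steps environment steps.

    Assumes one transition is stored per step and that sampling only succeeds once
    the buffer holds warmup_factor * batch_size transitions (as in ReplayBuffer.sample()).
    """
    n_updates = 0
    n_target_copies = 0
    stored = 0
    for step in range(total_steps):
        stored += 1  # the transition of this step has just been appended
        if stored >= warmup_factor * batch_size and step % train_period == 0:
            n_updates += 1  # update() would be called here
        if step > 0 and step % target_period == 0:
            n_target_copies += 1  # target_model.set_weights(model.get_weights())
    return n_updates, n_target_copies


print(count_updates(1000))  # (234, 9)
```

Once the buffer is warm, there is roughly one gradient update every 4 environment steps, and the target network lags behind the trained one by at most 100 steps.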

The batch size can be set to 32.

4 - update(): training the value network

Using the provided minibatch, one should implement the following part of the DQN algorithm:

  • For each transition (s_k, a_k, r_k, s'_k) in the minibatch:

    • Compute the target value t_k = r_k + \gamma \, \max_{a'} Q_{\theta'}(s'_k, a') using the target network.
  • Update the value network Q_{\theta} on \mathcal{D}_s to minimize:

    \mathcal{L}(\theta) = \mathbb{E}_{\mathcal{D}_s}[(t_k - Q_\theta(s_k, a_k))^2]

So we just need to define the targets for each transition in the minibatch, and call model.fit() on the trained network to minimize the mse between the current predictions Q_\theta(s_k, a_k) and the target.

But we have a problem: the network has two outputs for the actions left and right, but we have only one target for the action that was executed. We cannot compute the mse between a vector with 2 elements and a single value… They must have the same size.

As we only want to train the output neuron corresponding to the action a_k, we are going to:

  1. Use the trained network to predict the Q-value of both actions [Q_\theta(s_k, 0), Q_\theta(s_k, 1)].
  2. Replace one of the values with the target, for example [Q_\theta(s_k, 0), t_k] if the second action was chosen.
  3. Minimize the mse between [Q_\theta(s_k, 0), Q_\theta(s_k, 1)] and [Q_\theta(s_k, 0), t_k].

That way, the first output neuron has a squared error of 0, so it won’t learn anything. Only the second output neuron will have a non-zero mse and learn.

There are more efficient ways to do this (using masks), but this will do the trick, the drawback being that we have to make a forward pass on the minibatch before calling fit().
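The following NumPy sketch compares the "overwrite one target" trick with a masked loss on random fake data. One detail is worth knowing: Keras' 'mse' averages over the output neurons as well as over the batch. With 2 actions, the loss reported by fit() is therefore half the squared error of the executed action. A real masked implementation would compute such a loss inside a custom training step (for instance with tf.GradientTape) to avoid the extra forward pass. That part is not shown here, and the function names are hypothetical.

```python
import numpy as np


def replace_trick_loss(q_pred, actions, targets):
    """MSE against targets built by copying the predictions and overwriting Q(s_k, a_k).

    Averages over the output neurons, then over the batch, like Keras' 'mse'.
    """
    y_true = q_pred.copy()
    y_true[np.arange(len(actions)), actions] = targets
    return np.mean(np.mean((y_true - q_pred) ** 2, axis=-1))


def masked_loss(q_pred, actions, targets):
    """MSE restricted to the executed actions through a one-hot mask."""
    mask = np.eye(q_pred.shape[1])[actions]  # shape (batch, n_actions)
    q_taken = np.sum(q_pred * mask, axis=1)  # Q_theta(s_k, a_k)
    return np.mean((targets - q_taken) ** 2)


demo_rng = np.random.default_rng(0)
q_pred = demo_rng.normal(size=(32, 2))      # fake network outputs
actions = demo_rng.integers(0, 2, size=32)  # fake executed actions
targets = demo_rng.normal(size=32)          # fake Bellman targets

# Same error signal, up to the factor 1/n_actions from averaging over the outputs
print(2 * replace_trick_loss(q_pred, actions, targets), masked_loss(q_pred, actions, targets))
```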

The rest is pretty much the same as for your Q-learning agent. Do not forget that actions leading to a terminal state should only use the reward as a target, not the complete Bellman target r + \gamma \max Q.
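Here is a vectorized sketch of the target computation. It assumes the minibatch arrays have the shapes returned by ReplayBuffer.sample(), with dones as a boolean array. Fancy indexing replaces the Python loop over the minibatch. The function name bellman_targets is hypothetical.

```python
import numpy as np


def bellman_targets(q_current, q_next_target, actions, rewards, dones, gamma=0.99):
    """Targets for fit(): predictions everywhere, Bellman target on the executed action.

    q_current:     Q_theta(s_k, .) from the trained network, shape (B, n_actions)
    q_next_target: Q_theta'(s'_k, .) from the target network, shape (B, n_actions)
    dones:         boolean flags, True if s'_k is terminal
    """
    targets = np.array(q_current, dtype=float)    # copy of the current predictions
    next_max = np.max(q_next_target, axis=1)      # max_a' Q_theta'(s'_k, a')
    next_max = np.where(dones, 0.0, next_max)     # terminal: the target is the reward only
    targets[np.arange(len(actions)), actions] = rewards + gamma * next_max
    return targets


q_current = np.array([[0.1, 0.2], [0.3, 0.4]])
q_next_target = np.array([[1.0, 2.0], [3.0, 0.5]])
actions = np.array([0, 1])
rewards = np.array([1.0, 1.0])
dones = np.array([False, True])

print(bellman_targets(q_current, q_next_target, actions, rewards, dones))
```

In this example, row 0 gets 1 + 0.99 * 2 = 2.98 for action 0. Row 1 is terminal, so action 1 only gets the reward 1.0. The other entries keep their predicted values and therefore produce no error. One further subtlety: the solution in this notebook stores done = terminal or truncated, so the 200-step time limit is also treated as terminal. Storing only the terminal flag in the buffer would let truncated transitions keep bootstrapping.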

Hint: as we sample a minibatch of 32 transitions, it is faster to call:

Q_values = np.array(trained_model.predict_on_batch(states))

than:

Q_values = trained_model.predict(states)

for reasons internal to tensorflow. Note that with tf2, you need to cast the result to numpy arrays as eager mode is now the default.

The method should return the training loss, which is contained in the History object returned by model.fit(). model.fit() should be called for one epoch only, a batch size of 32, and verbose set to 0.

5 - test()

This method should run one episode with epsilon set to 0, without learning. The number of steps should be returned (do not bother discounting with gamma, the goal is to be up for 200 steps).

Q: Let’s go! Run the agent for 150 episodes and observe how fast it manages to keep the pole up for 200 steps.

Beware that training the same network twice can lead to very different results. In particular, policy collapse can happen: the network was almost perfect, but suddenly crashes and becomes random. Just be patient.

You can visualize a test trial using the GymRecorder: you just need to set the env attribute of your DQN agent to a new env with the render mode rgb_array_list and record the frames at the end.

class DQNAgent:
    
    def __init__(self, env, create_model, learning_rate, epsilon, epsilon_decay, gamma, batch_size, target_update_period, training_update_period, buffer_limit):
        self.env = env

        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_period = target_update_period
        self.training_update_period = training_update_period
        
        # Create the Q-network and the target network
        tf.keras.backend.clear_session() # start by deleting all existing models to be gentle on the RAM
        self.model = create_model(self.env, self.learning_rate)
        self.target_model = create_model(self.env, self.learning_rate)
        self.target_model.set_weights(self.model.get_weights())

        # Create the replay memory
        self.buffer = ReplayBuffer(buffer_limit)
                
    def act(self, state):

        # epsilon-greedy
        if np.random.rand() < self.epsilon: # Random selection
            action = self.env.action_space.sample()
        else: # Use the Q-network to get the greedy action
            action = self.model.predict(state.reshape((1, self.env.observation_space.shape[0])), verbose=0)[0].argmax()

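        # Decay epsilon exponentially, down to a floor of 0.05.
        # The guard leaves epsilon untouched once it is at or below the floor,
        # so that test() can set it to 0 for a purely greedy episode.
        if self.epsilon > 0.05:
            self.epsilon = max(0.05, self.epsilon * (1 - self.epsilon_decay))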

        return action
    
    def update(self, batch):
        
        # Get the minibatch
        states, actions, rewards, next_states, dones = batch 
        
        # Predict the Q-values in the current state
        targets = np.array(self.model.predict_on_batch(states))
        
        # Predict the Q-values in the next state using the target model
        next_Q_value = np.array(self.target_model.predict_on_batch(next_states)).max(axis=1)
        
        # Terminal states have a value of 0
        next_Q_value[dones] = 0.0
        
        # Compute the target
        for i in range(self.batch_size):
            targets[i, actions[i]] = rewards[i] + self.gamma * next_Q_value[i]
            
        # Train the model on the minibatch
        history = self.model.fit(states, targets, epochs=1, batch_size=self.batch_size, verbose=0)
        
        return history.history['loss'][0]

    def train(self, nb_episodes):

        steps = 0
        returns = []
        losses = []

        for episode in range(nb_episodes):
            
            # Reset
            state, info = self.env.reset()
            done = False
            steps_episode = 0
            return_episode = 0

            loss_episode = []
            
            # Sample the episode
            while not done:

                # Select an action 
                action = self.act(state)
            
                # Perform the action
                next_state, reward, terminal, truncated, info = self.env.step(action)

                # End of the episode
                done = terminal or truncated
                
                # Store the transition
                self.buffer.append(state, action, reward, next_state, done)
            
                # Every training_update_period steps, train the NN on a minibatch
                if steps % self.training_update_period == 0:
                    # sample() returns [] while the buffer holds fewer than 2*batch_size transitions
                    batch = self.buffer.sample(self.batch_size)
                    if len(batch) > 0:
                        loss = self.update(batch)
                        loss_episode.append(loss)

                # Update the target model
                if steps > self.target_update_period and steps % self.target_update_period == 0:
                    self.target_model.set_weights(self.model.get_weights())
            
                # Go in the next state
                state = next_state
                
                # Increment time
                steps += 1
                steps_episode += 1
                return_episode += reward
                    
                if done:
                    break
            
            # Store info
            returns.append(return_episode)
            losses.append(np.mean(loss_episode))

            # Print info
            clear_output(wait=True)
            print('Episode', episode+1)
            print(' total steps:', steps)
            print(' length of the episode:', steps_episode)
            print(' return of the episode:', return_episode)
            print(' current loss:', np.mean(loss_episode))
            print(' epsilon:', self.epsilon)

        return returns, losses

    def test(self, render=True):

        old_epsilon = self.epsilon
        self.epsilon = 0.0
        
        state, info = self.env.reset()
        nb_steps = 0
        done = False
        
        while not done:
            action = self.act(state)
            next_state, reward, terminal, truncated, info = self.env.step(action)
            done = terminal or truncated
            state = next_state
            nb_steps += 1
        
        self.epsilon = old_epsilon
        return nb_steps
# Parameters
nb_episodes = 150
batch_size = 32

epsilon = 1.0
epsilon_decay = 0.0005

gamma = 0.99

learning_rate = 0.005 
buffer_limit = 5000
target_update_period = 100
training_update_period = 4
# Create the environment
env = gym.make('CartPole-v0')

# Create the agent
agent = DQNAgent(env, create_model, learning_rate, epsilon, epsilon_decay, gamma, batch_size, target_update_period, training_update_period, buffer_limit)

# Train the agent
returns, losses = agent.train(nb_episodes)

# Plot the returns
plt.figure(figsize=(10, 6))
plt.plot(returns)
plt.plot(running_average(returns, 10))
plt.xlabel("Episodes")
plt.ylabel("Returns")

# Plot the losses
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel("Episodes")
plt.ylabel("Training loss")
plt.show()
Episode 150
 total steps: 17021
 length of the episode: 138
 return of the episode: 138.0
 current loss: 7.081971720286778
 epsilon: 0.05

# Test the network
env = gym.make('CartPole-v0', render_mode="rgb_array_list")
recorder = GymRecorder(env)
agent.env = env

nb_steps = agent.test()
recorder.record(env.render())
print("Number of steps:", nb_steps)

video = "videos/cartpole-dqn.gif"
recorder.make_video(video)
ipython_display(video, loop=0, autoplay=1)
Number of steps: 162
MoviePy - Building file videos/cartpole-dqn.gif with imageio.

Q: How does the loss evolve? Does it make sense?

A: The Q-values are non-stationary: the initial Q-values are very small (the agent fails almost immediately), while they are around 100 after training (200 steps with a reward of +1, but discounted with gamma). The mse increases with the magnitude of the Q-values, so the loss is a poor indicator of the convergence of the network.
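The order of magnitude of these Q-values can be checked with a geometric series. With a reward of +1 per step, the discounted return of an n-step episode is \sum_{k=0}^{n-1} \gamma^k = (1 - \gamma^n)/(1 - \gamma). This tends to 1/(1 - \gamma) = 100 for \gamma = 0.99. Here is a quick sketch; the helper name is hypothetical.

```python
def discounted_return(n_steps, gamma=0.99, reward=1.0):
    """Discounted return of n_steps constant rewards: sum_{k=0}^{n-1} gamma^k * reward."""
    return reward * (1.0 - gamma ** n_steps) / (1.0 - gamma)


print(discounted_return(200))               # about 86.6: Q-value at the start of a perfect episode
print(discounted_return(200, reward=0.01))  # about 0.866 with rewards divided by 100
print(discounted_return(10))                # about 9.6 for an episode failing after 10 steps
```

The targets thus move from about 10 for early, short episodes to about 87 for perfect ones, and the squared TD errors scale accordingly. Dividing the rewards by 100 divides all these values by 100 as well.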

Reward scaling

Q: Do a custom test trial after training (i.e. do not call test(), but copy and adapt its code) and plot the Q-value of the selected action at each time step. Do you think it is a good output for the network? Could it explain why learning is so slow?

env = gym.make('CartPole-v0')
agent.env = env

# No exploration
agent.epsilon = 0.0
        
state, info = agent.env.reset()
done = False

Q_values = []

while not done:
    action = agent.act(state)
    Q_values.append(agent.model.predict(state.reshape((1, 4)), verbose=0)[0][action])
    next_state, reward, terminal, truncated, info = agent.env.step(action)
    done = terminal or truncated
    state = next_state

plt.figure(figsize=(10, 6))
plt.plot(Q_values)
plt.xlabel("Steps")
plt.ylabel("Q-value")
plt.show()

A: The predicted Q-values at the beginning of learning are close to 0, as the weights are randomly initialized. They should grow to around 100, which takes a lot of time. If the target Q-values were around 1, learning might be much faster.

Q: Implement reward scaling by dividing the received rewards by a fixed factor of 100 when computing the Bellman targets. That way, the final Q-values will be around 1, which may be much easier to learn.

Tip: in order to avoid a huge copy and paste, you can inherit from your DQNAgent and only reimplement the desired function:

class ScaledDQNAgent(DQNAgent):
    def update(self, batch):
        # Change the content of this function only
        ...

You should reduce the learning rate a bit (e.g. 0.001), as the magnitude of the targets has changed.

class ScaledDQNAgent(DQNAgent):
    
    def update(self, batch):
        
        # Get the minibatch
        states, actions, rewards, next_states, dones = batch 
        
        # Predict the Q-values in the current state
        targets = np.array(self.model.predict_on_batch(states))
        
        # Predict the Q-values in the next state using the target model
        next_Q_value = np.array(self.target_model.predict_on_batch(next_states)).max(axis=1)
        
        # Terminal states have a value of 0
        next_Q_value[dones] = 0.0
        
        # Compute the target
        for i in range(self.batch_size):
            targets[i, actions[i]] = rewards[i]/100. + self.gamma * next_Q_value[i]
            
        # Train the model on the minibatch
        history = self.model.fit(states, targets, epochs=1, batch_size=self.batch_size, verbose=0)
        
        return history.history['loss'][0]
# Create the environment
env = gym.make('CartPole-v0')

# Create the agent
learning_rate = 0.002
agent = ScaledDQNAgent(env, create_model, learning_rate, epsilon, epsilon_decay, gamma, batch_size, target_update_period, training_update_period, buffer_limit)

# Train the agent
returns, losses = agent.train(nb_episodes)

# Plot the returns
plt.figure(figsize=(10, 6))
plt.plot(returns)
plt.plot(running_average(returns, 10))
plt.xlabel("Episodes")
plt.ylabel("Returns")

# Plot the losses
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel("Episodes")
plt.ylabel("Training loss")
plt.show()
Episode 150
 total steps: 18540
 length of the episode: 200
 return of the episode: 200.0
 current loss: 0.0003936138337303419
 epsilon: 0.05

# Test the network
env = gym.make('CartPole-v0', render_mode="rgb_array_list")
recorder = GymRecorder(env)
agent.env = env

agent.epsilon = 0.0

# Q-values     
state, info = agent.env.reset()
done = False
Q_values = []

while not done:
    action = agent.act(state)
    Q_values.append(agent.model.predict(state.reshape((1, 4)), verbose=0)[0][action])
    next_state, reward, terminal, truncated, info = agent.env.step(action)
    done = terminal or truncated
    state = next_state

# Plot the Q-values
plt.figure(figsize=(10, 6))
plt.plot(Q_values)
plt.xlabel("Steps")
plt.ylabel("Q-value")
plt.show()

recorder.record(env.render())

video = "videos/cartpole-dqn-scaled.gif"
recorder.make_video(video)
ipython_display(video, loop=0, autoplay=1)
MoviePy - Building file videos/cartpole-dqn-scaled.gif with imageio.