Variational autoencoder

The goal of this exercise is to implement a VAE and apply it on the MNIST dataset. The code is adapted from the keras tutorial:

https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/vae.ipynb

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Fetch the MNIST data
(X_train, t_train), (X_test, t_test) = tf.keras.datasets.mnist.load_data()
print("Training data:", X_train.shape, t_train.shape)
print("Test data:", X_test.shape, t_test.shape)

# Normalize the values
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255.

# Mean removal
X_mean = np.mean(X_train, axis=0)
X_train -= X_mean
X_test -= X_mean

# One-hot encoding
T_train = tf.keras.utils.to_categorical(t_train, 10)
T_test = tf.keras.utils.to_categorical(t_test, 10)
Training data: (60000, 28, 28) (60000,)
Test data: (10000, 28, 28) (10000,)

As a reminder, a VAE is composed of two parts:

Two fundamental aspects of a VAE are not standard in keras:

  1. The sampling layer \mathbf{z} \sim \mathcal{N}(\mu_\mathbf{x}, \sigma_\mathbf{x}) using the reparameterization trick.
  2. The VAE loss:

\mathcal{L}(\theta, \phi) = \mathbb{E}_{\mathbf{x} \in \mathcal{D}, \xi \sim \mathcal{N}(0, 1)} [ - \log p_\theta(\mathbf{\mu_x} + \mathbf{\sigma_x} \, \xi) + \dfrac{1}{2} \, \sum_{k=1}^K (\mathbf{\sigma_x^2} + \mathbf{\mu_x}^2 -1 - \log \mathbf{\sigma_x^2})]

This will force us to dive a bit deeper into the mechanics of tensorflow, but it is not that difficult since the release of tensorflow 2.0 and the eager execution mode.

Gradient tapes: redefining the learning procedure

Let’s first have a look at how to define custom losses. There is an easier way to define custom losses with keras (https://keras.io/api/losses/#creating-custom-losses), but we will need this sightly more complicated variant for the VAE.

Let’s reuse the CNN you implemented last time using the functional API on MNIST, but not compile it yet:

def create_model():

    inputs = tf.keras.layers.Input((28, 28, 1))

    x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='valid')(inputs)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tf.keras.layers.Dropout(0.5)(x)

    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='valid')(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tf.keras.layers.Dropout(0.5)(x)

    x = tf.keras.layers.Flatten()(x)

    x = tf.keras.layers.Dense(150, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)

    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

    model = tf.keras.Model(inputs, outputs)
    print(model.summary())

    return model

In order to have access to the internals of the training procedure, one of the possible methods is to inherit the tf.keras.Model class and redefine the train_step and (optionally) test_step methods.

The following cell redefines a model for the previous CNN and minimizes the categorical cross-entropy while tracking the loss and accuracy, so it is completely equivalent to:

model.compile(
    loss="categorical_crossentropy", 
    optimizer=optimizer, 
    metrics=['accuracy'])

Have a look at the code, but we will go through it step by step afterwards.

class CNN(tf.keras.Model):

    def __init__(self):
        super(CNN, self).__init__()

        # Model
        self.model = create_model()

        # Metrics
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.accuracy_tracker = tf.keras.metrics.Accuracy(name="accuracy")

    @property
    def metrics(self):
        "Track the loss and accuracy"
        return [self.loss_tracker, self.accuracy_tracker]

    def train_step(self, data):
        
        # Get the data of the minibatch
        X, t = data
        
        # Use GradientTape to record everything we need to compute the gradient
        with tf.GradientTape() as tape:

            # Prediction using the model
            y = self.model(X, training=True)
            
            # Cross-entropy loss
            loss = tf.reduce_mean(
                tf.reduce_sum(
                    - t * tf.math.log(y), # Cross-entropy
                    axis=1 # First index is the batch size, the second is the classes
                )
            )
        
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        
        # Apply gradients using the optimizer
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        # Update metrics 
        self.loss_tracker.update_state(loss)
        true_class = tf.reshape(tf.argmax(t, axis=1), shape=(-1, 1))
        predicted_class = tf.reshape(tf.argmax(y, axis=1), shape=(-1, 1))
        self.accuracy_tracker.update_state(true_class, predicted_class)
        
        # Return a dict mapping metric names to current value
        return {"loss": self.loss_tracker.result(), 'accuracy': self.accuracy_tracker.result()} 

    def test_step(self, data):
        
        # Get data
        X, t = data
        
        # Prediction
        y = self.model(X, training=False)
            
        # Loss
        loss = tf.reduce_mean(
            tf.reduce_sum(
                    - t * tf.math.log(y), # Cross-entropy
                    axis=1
            )
        )
        
        # Update metrics 
        self.loss_tracker.update_state(loss)
        true_class = tf.reshape(tf.argmax(t, axis=1), shape=(-1, 1))
        predicted_class = tf.reshape(tf.argmax(y, axis=1), shape=(-1, 1))
        self.accuracy_tracker.update_state(true_class, predicted_class)
        
        # Return a dict mapping metric names to current value
        return {"loss": self.loss_tracker.result(), 'accuracy': self.accuracy_tracker.result()} 
                

The constructor of the new CNN class creates the model defined by create_model() and stores it as an attribute.

Note: it would be actually more logical to create layers directly here, as we now have a model containing a model, but this is simpler for the VAE architecture.

The constructor also defines the metrics that should be tracked when training. Here we track the loss and accuracy of the model, using objects of tf.keras.metrics (check https://keras.io/api/metrics/ for a list of metrics you can track).

The metrics are furthermore declared in the metrics property, so that you can now avoid passing metrics=['accuracy'] to compile(). The default Model only has 'loss' as a default metric.

class CNN(tf.keras.Model):

    def __init__(self):
        super(CNN, self).__init__()

        # Model
        self.model = create_model()

        # Metrics
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.accuracy_tracker = tf.keras.metrics.Accuracy(name="accuracy")

    @property
    def metrics(self):
        "Track the loss and accuracy"
        return [self.loss_tracker, self.accuracy_tracker]

The training procedure is defined in the train_step(data) method of the class.

    def train_step(self, data):
        
        # Get the data of the minibatch
        X, t = data

data is a minibatch of data iteratively passed by model.fit(). X and t are tensors (multi-dimensional arrays) representing the inputs and targets. On MNIST, X has the shape (batch_size, 28, 28, 1) and t is (batch_size, 10). The rest of the method defines the loss function, computes its gradient w.r.t the learnable parameters and pass it the optimizer to change their value.

To get the output of the network on the minibatch, one simply has to call:

y = self.model(X)

which returns a (batch_size, 10) tensor. However, this forward pass does not keep in memory the activity of the hidden layers: all it cares about is the prediction. But when applying backpropagation, you need this internal information to compute the gradient.

In tensorflow 2.x, you can force the model to record internal activity using the eager execution mode and gradient tapes (as in the tape of an audio recorder):

with tf.GradientTape() as tape:
    y = self.model(X, training=True)

It is not a big problem if you are not familiar with Python contexts: all you need to know is that the tape object will “see” everything that happens when calling y = self.model(X, training=True), i.e. it will record the hidden activations in the model.

The next thing to do inside the tape is to compute the loss of the model on the minibatch. Here we minimize the categorical cross-entropy:

\mathcal{L}(\theta) = \frac{1}{N} \, \sum_{i=1}^N \sum_{j=1}^C - t^i_j \, \log y^i_j

where N is the batch size, C the number of classes, t^i_j the j-th element of the i-th target vector and y^i_j the predicted probability for class j and the i-th sample.

We therefore need to take our two tensors t and y and compute that loss function, but recording everything (so inside the tape context).

There are several ways to do that, for example by calling directly the built-in categorical cross-entropy object of keras on the data:

loss = tf.keras.losses.CategoricalCrossentropy()(t, y)

Another way to do it is to realize that tensorflow tensors are completely equivalent to numpy arrays: you can apply mathematical operations (sum, element-wise multiplication, log, etc.) on them as if they were regular arrays (internally, that is another story…).

You can for example add t and two times y as they have the same shape:

loss = t + 2.0 * y

loss would then be a tensor of the same shape. You can get the shape of a tensor with tf.shape(loss) just like in numpy.

Mathematical operation are in the tf.math module (https://www.tensorflow.org/api_docs/python/tf/math), for example with the log:

loss = t + tf.math.log(y)

* is by default the element-wise multiplication:

loss = - t * tf.math.log(y)

Here, loss is still a (batch_size, 10) tensor. We still need to sum over the 10 classes and take the mean over the minibatch to get a single number.

Summing over the second dimension of this tensor can be done with tf.reduce_sum:

loss = tf.reduce_sum(
    - t * tf.math.log(y), 
    axis=1 # First index is the batch size, the second is the classes
)

This gives us a vector with batch_size elements containing the individual losses for the minibatch. In order to compute its mean over the minibatch, we only need to call tf.reduce_mean():

loss = tf.reduce_mean(
            tf.reduce_sum(
                - t * tf.math.log(y),
                axis=1 
            )
        )

That’s it, we have redefined the categorical cross-entropy loss function on a minibatch using elementary numerical operations! Doing this inside the tape allows tensorflow to keep track of each sample of the minibatch individually: otherwise, it would not know how the loss (a single number) depends on each prediction y^i and therefore on the parameters of the NN.

Now that we have the loss function as a function of the trainable parameters of the NN on the minibatch, we can ask tensorflow for its gradient:

grads = tape.gradient(loss, self.trainable_weights)

Backpropagation is still a one-liner. self.trainable_weights contains all weights and biases in the model, while tape.gradient() apply backpropagation to compute the gradient of the loss function w.r.t them.

We can then pass this gradient to the optimizer (SGD or Adam, which will be passed to compile()) so that it updates the parameters:

self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

Finally, we can update our metrics so that our custom loss and the accuracy are tracked during training:

self.loss_tracker.update_state(loss)

true_class = tf.reshape(tf.argmax(t, axis=1), shape=(-1, 1))
predicted_class = tf.reshape(tf.argmax(y, axis=1), shape=(-1, 1))
self.accuracy_tracker.update_state(true_class, predicted_class)

For the accuracy, we need to pass the class (predicted or ground truth), not the probabilities.

The test_step() method does roughly the same as train_step(), except that it does not modify the parameters: it is called on the validation data in order to compute the metrics. As we do not learn, we do not actually need the tape.

Q: Create the custom CNN model and train it on MNIST. When compiling the model, you only need to pass it the right optimizer, as the loss function and the metrics are already defined in the model. Check that you get the same results as last time.

# Delete all previous models to free memory
tf.keras.backend.clear_session()
    
# Create the custom model
model = CNN()

# Optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# Compile
model.compile(
    optimizer=optimizer, # learning rule
)

# Training
history = tf.keras.callbacks.History()
model.fit(
    X_train, T_train,
    batch_size=64, 
    epochs=20,
    validation_split=0.1,
    callbacks=[history]
)

# Testing
score = model.evaluate(X_test, T_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

plt.figure(figsize=(15, 6))

plt.subplot(121)
plt.plot(history.history['loss'], '-r', label="Training")
plt.plot(history.history['val_loss'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.legend()

plt.subplot(122)
plt.plot(history.history['accuracy'], '-r', label="Training")
plt.plot(history.history['val_accuracy'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()

plt.show()
2022-12-13 16:07:41.101602: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-12-13 16:07:41.102693: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 32)       0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 13, 13, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 5, 5, 64)          0         
                                                                 
 flatten (Flatten)           (None, 1600)              0         
                                                                 
 dense (Dense)               (None, 150)               240150    
                                                                 
 dropout_2 (Dropout)         (None, 150)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                1510      
                                                                 
=================================================================
Total params: 260,476
Trainable params: 260,476
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/20
2022-12-13 16:07:41.632132: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2022-12-13 16:07:41.829108: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
844/844 [==============================] - ETA: 0s - loss: 0.4665 - accuracy: 0.8480
2022-12-13 16:07:53.832423: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
844/844 [==============================] - 13s 13ms/step - loss: 0.4665 - accuracy: 0.8480 - val_loss: 0.0944 - val_accuracy: 0.9718
Epoch 2/20
844/844 [==============================] - 11s 13ms/step - loss: 0.1466 - accuracy: 0.9546 - val_loss: 0.0637 - val_accuracy: 0.9822
Epoch 3/20
844/844 [==============================] - 10s 12ms/step - loss: 0.1094 - accuracy: 0.9665 - val_loss: 0.0537 - val_accuracy: 0.9857
Epoch 4/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0907 - accuracy: 0.9716 - val_loss: 0.0470 - val_accuracy: 0.9867
Epoch 5/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0799 - accuracy: 0.9751 - val_loss: 0.0417 - val_accuracy: 0.9882
Epoch 6/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0725 - accuracy: 0.9771 - val_loss: 0.0406 - val_accuracy: 0.9877
Epoch 7/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0682 - accuracy: 0.9786 - val_loss: 0.0367 - val_accuracy: 0.9900
Epoch 8/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0636 - accuracy: 0.9799 - val_loss: 0.0365 - val_accuracy: 0.9890
Epoch 9/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0602 - accuracy: 0.9811 - val_loss: 0.0335 - val_accuracy: 0.9898
Epoch 10/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0581 - accuracy: 0.9815 - val_loss: 0.0339 - val_accuracy: 0.9893
Epoch 11/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0529 - accuracy: 0.9838 - val_loss: 0.0325 - val_accuracy: 0.9902
Epoch 12/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0511 - accuracy: 0.9838 - val_loss: 0.0328 - val_accuracy: 0.9900
Epoch 13/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0497 - accuracy: 0.9840 - val_loss: 0.0304 - val_accuracy: 0.9905
Epoch 14/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0472 - accuracy: 0.9850 - val_loss: 0.0321 - val_accuracy: 0.9908
Epoch 15/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0443 - accuracy: 0.9860 - val_loss: 0.0295 - val_accuracy: 0.9908
Epoch 16/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0438 - accuracy: 0.9870 - val_loss: 0.0297 - val_accuracy: 0.9912
Epoch 17/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0420 - accuracy: 0.9872 - val_loss: 0.0294 - val_accuracy: 0.9923
Epoch 18/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0412 - accuracy: 0.9871 - val_loss: 0.0280 - val_accuracy: 0.9925
Epoch 19/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0395 - accuracy: 0.9870 - val_loss: 0.0274 - val_accuracy: 0.9912
Epoch 20/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0386 - accuracy: 0.9871 - val_loss: 0.0273 - val_accuracy: 0.9905
Test loss: 0.02209572307765484
Test accuracy: 0.9928000569343567

Q: Redefine the model so that it minimizes the mean-square error (t-y)^2 instead of the cross-entropy. What happens?

Hint: squaring a tensor element-wise is done by applying **2 on it just like in numpy.

class CNN(tf.keras.Model):

    def __init__(self):
        super(CNN, self).__init__()

        # Model
        self.model = create_model()

        # Metrics
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.accuracy_tracker = tf.keras.metrics.Accuracy(name="accuracy")

    @property
    def metrics(self):
        "Track the loss and accuracy"
        return [self.loss_tracker, self.accuracy_tracker]
        
    def train_step(self, data):
        # Get the data of the minibatch
        X, t = data
        
        # Use GradientTape to record everything we need to compute the gradient
        with tf.GradientTape() as tape:

            # Prediction using the model
            y = self.model(X, training=True)
            
            # Cross-entropy loss
            loss = tf.reduce_mean(
                tf.reduce_sum(
                    (t - y)**2, # Mean square error
                    axis=1 # First index is the batch size, the second is the classes
                )
            )
        
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        
        # Apply gradients using the optimizer
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        # Update metrics 
        self.loss_tracker.update_state(loss)
        true_class = tf.reshape(tf.argmax(t, axis=1), shape=(-1, 1))
        predicted_class = tf.reshape(tf.argmax(y, axis=1), shape=(-1, 1))
        self.accuracy_tracker.update_state(true_class, predicted_class)
        
        # Return a dict mapping metric names to current value
        return {"loss": self.loss_tracker.result(), 'accuracy': self.accuracy_tracker.result()} 

    def test_step(self, data):
        
        # Get data
        X, t = data
        
        # Prediction
        y = self.model(X, training=False)
            
        # Loss
        loss = tf.reduce_mean(
            tf.reduce_sum(
                    (t - y)**2, # Mean square error
                    axis=1
            )
        )
        
        # Update metrics 
        self.loss_tracker.update_state(loss)
        true_class = tf.reshape(tf.argmax(t, axis=1), shape=(-1, 1))
        predicted_class = tf.reshape(tf.argmax(y, axis=1), shape=(-1, 1))
        self.accuracy_tracker.update_state(true_class, predicted_class)
        
        # Return a dict mapping metric names to current value
        return {"loss": self.loss_tracker.result(), 'accuracy': self.accuracy_tracker.result()} 
                
# Delete all previous models to free memory
tf.keras.backend.clear_session()
    
# Create the custom model
model = CNN()

# Optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# Compile
model.compile(
    optimizer=optimizer, # learning rule
)

# Training
history = tf.keras.callbacks.History()
model.fit(
    X_train, T_train,
    batch_size=64, 
    epochs=20,
    validation_split=0.1,
    callbacks=[history]
)

# Testing
score = model.evaluate(X_test, T_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

plt.figure(figsize=(15, 6))

plt.subplot(121)
plt.plot(history.history['loss'], '-r', label="Training")
plt.plot(history.history['val_loss'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.legend()

plt.subplot(122)
plt.plot(history.history['accuracy'], '-r', label="Training")
plt.plot(history.history['val_accuracy'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()

plt.show()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 32)       0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 13, 13, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 5, 5, 64)          0         
                                                                 
 flatten (Flatten)           (None, 1600)              0         
                                                                 
 dense (Dense)               (None, 150)               240150    
                                                                 
 dropout_2 (Dropout)         (None, 150)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                1510      
                                                                 
=================================================================
Total params: 260,476
Trainable params: 260,476
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/20
  1/844 [..............................] - ETA: 3:56 - loss: 0.9039 - accuracy: 0.1250
2022-12-13 16:11:06.392387: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
844/844 [==============================] - ETA: 0s - loss: 0.3687 - accuracy: 0.7129
2022-12-13 16:11:15.682565: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
844/844 [==============================] - 10s 11ms/step - loss: 0.3687 - accuracy: 0.7129 - val_loss: 0.0669 - val_accuracy: 0.9598
Epoch 2/20
844/844 [==============================] - 10s 11ms/step - loss: 0.1052 - accuracy: 0.9318 - val_loss: 0.0406 - val_accuracy: 0.9740
Epoch 3/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0752 - accuracy: 0.9510 - val_loss: 0.0323 - val_accuracy: 0.9800
Epoch 4/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0622 - accuracy: 0.9591 - val_loss: 0.0314 - val_accuracy: 0.9795
Epoch 5/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0539 - accuracy: 0.9648 - val_loss: 0.0242 - val_accuracy: 0.9845
Epoch 6/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0475 - accuracy: 0.9689 - val_loss: 0.0232 - val_accuracy: 0.9848
Epoch 7/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0439 - accuracy: 0.9712 - val_loss: 0.0205 - val_accuracy: 0.9865
Epoch 8/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0399 - accuracy: 0.9739 - val_loss: 0.0193 - val_accuracy: 0.9867
Epoch 9/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0379 - accuracy: 0.9758 - val_loss: 0.0184 - val_accuracy: 0.9877
Epoch 10/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0353 - accuracy: 0.9774 - val_loss: 0.0180 - val_accuracy: 0.9877
Epoch 11/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0344 - accuracy: 0.9776 - val_loss: 0.0175 - val_accuracy: 0.9890
Epoch 12/20
844/844 [==============================] - 10s 12ms/step - loss: 0.0338 - accuracy: 0.9779 - val_loss: 0.0173 - val_accuracy: 0.9885
Epoch 13/20
844/844 [==============================] - 9s 11ms/step - loss: 0.0313 - accuracy: 0.9799 - val_loss: 0.0171 - val_accuracy: 0.9885
Epoch 14/20
844/844 [==============================] - 9s 11ms/step - loss: 0.0305 - accuracy: 0.9808 - val_loss: 0.0168 - val_accuracy: 0.9887
Epoch 15/20
844/844 [==============================] - 9s 11ms/step - loss: 0.0286 - accuracy: 0.9820 - val_loss: 0.0156 - val_accuracy: 0.9898
Epoch 16/20
844/844 [==============================] - 9s 11ms/step - loss: 0.0275 - accuracy: 0.9824 - val_loss: 0.0154 - val_accuracy: 0.9895
Epoch 17/20
844/844 [==============================] - 9s 11ms/step - loss: 0.0280 - accuracy: 0.9819 - val_loss: 0.0167 - val_accuracy: 0.9897
Epoch 18/20
844/844 [==============================] - 10s 11ms/step - loss: 0.0267 - accuracy: 0.9831 - val_loss: 0.0155 - val_accuracy: 0.9900
Epoch 19/20
844/844 [==============================] - 9s 11ms/step - loss: 0.0274 - accuracy: 0.9821 - val_loss: 0.0160 - val_accuracy: 0.9888
Epoch 20/20
844/844 [==============================] - 9s 11ms/step - loss: 0.0251 - accuracy: 0.9838 - val_loss: 0.0152 - val_accuracy: 0.9902
Test loss: 0.014269283041357994
Test accuracy: 0.9901000261306763

A: Nothing, it also works… Only the loss has different values.

Custom layers

Keras layers take a tensor as input (the output of the previous layer on a minibatch) and transform it into another tensor, possibly using trainable parameters. As we have seen, tensorflow allows to manipulate tensors and apply differentiable operations on them, so we could redefine the function made by a keras layer using tensorflow operations.

The following cell shows how to implement a dummy layer that takes a tensor T as input (the first dimension is the batch size) and returns the tensor \exp - \lambda \, T, \lambda being a fixed parameter.

class ExponentialLayer(tf.keras.layers.Layer):
    """Layer performing element-wise exponentiation."""

    def __init__(self, factor=1.0):
        super(ExponentialLayer, self).__init__()
        self.factor = factor

    def call(self, inputs):
        return tf.exp(- self.factor*inputs)

ExponentialLayer inherits from tf.keras.layers.Layer and redefines the call() method that defines the forward pass. Here we simply return the corresponding tensor.

The layer can then be used in a functional model directly:

x = ExponentialLayer(factor=1.0)(x)

As we use tensorflow operators, it knows how to differentiate it when applying backpropagation.

More information on how to create new layers can be found at https://keras.io/guides/making_new_layers_and_models_via_subclassing. FYI, this is how you would redefine a fully-connected layer without an activation function, using a trainable weight matrix and bias vector:

class Linear(tf.keras.layers.Layer):
    def __init__(self, units=32):
        "Number of neurons in the layer."
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        "Create the weight matrix and bias vector once we know the shape of the previous layer."
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        "Return W*X + b"
        return tf.matmul(inputs, self.w) + self.b

Q: Add the exponential layer to the CNN between the last FC layer and the output layer. Change the value of the parameter. Does it still work?

# Model
inputs = tf.keras.layers.Input((28, 28, 1))

x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='valid')(inputs)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
x = tf.keras.layers.Dropout(0.5)(x)

x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='valid')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
x = tf.keras.layers.Dropout(0.5)(x)

x = tf.keras.layers.Flatten()(x)

x = tf.keras.layers.Dense(150, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)

x = ExponentialLayer(factor=1.0)(x)

outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.Model(inputs, outputs)
print(model.summary())

# Optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# Compile
model.compile(
    loss="categorical_crossentropy",
    optimizer=optimizer, # learning rule
    metrics=["accuracy"]
)

# Training
history = tf.keras.callbacks.History()
model.fit(
    X_train, T_train,
    batch_size=64, 
    epochs=20,
    validation_split=0.1,
    callbacks=[history]
)

# Testing
score = model.evaluate(X_test, T_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


plt.figure(figsize=(15, 6))

plt.subplot(121)
plt.plot(history.history['loss'], '-r', label="Training")
plt.plot(history.history['val_loss'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.legend()

plt.subplot(122)
plt.plot(history.history['accuracy'], '-r', label="Training")
plt.plot(history.history['val_accuracy'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()

plt.show()
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 13, 13, 32)       0         
 2D)                                                             
                                                                 
 dropout_3 (Dropout)         (None, 13, 13, 32)        0         
                                                                 
 conv2d_3 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 dropout_4 (Dropout)         (None, 5, 5, 64)          0         
                                                                 
 flatten_1 (Flatten)         (None, 1600)              0         
                                                                 
 dense_2 (Dense)             (None, 150)               240150    
                                                                 
 dropout_5 (Dropout)         (None, 150)               0         
                                                                 
 exponential_layer (Exponent  (None, 150)              0         
 ialLayer)                                                       
                                                                 
 dense_3 (Dense)             (None, 10)                1510      
                                                                 
=================================================================
Total params: 260,476
Trainable params: 260,476
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/20
  1/844 [..............................] - ETA: 4:06 - loss: 3.2513 - accuracy: 0.0938
2022-12-13 16:14:20.089533: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
844/844 [==============================] - ETA: 0s - loss: 0.7619 - accuracy: 0.7816
2022-12-13 16:14:32.086391: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
844/844 [==============================] - 13s 15ms/step - loss: 0.7619 - accuracy: 0.7816 - val_loss: 0.1412 - val_accuracy: 0.9585
Epoch 2/20
844/844 [==============================] - 13s 15ms/step - loss: 0.2267 - accuracy: 0.9391 - val_loss: 0.0902 - val_accuracy: 0.9723
Epoch 3/20
844/844 [==============================] - 13s 15ms/step - loss: 0.1568 - accuracy: 0.9566 - val_loss: 0.0695 - val_accuracy: 0.9798
Epoch 4/20
844/844 [==============================] - 13s 15ms/step - loss: 0.1267 - accuracy: 0.9646 - val_loss: 0.0653 - val_accuracy: 0.9817
Epoch 5/20
844/844 [==============================] - 13s 15ms/step - loss: 0.1100 - accuracy: 0.9676 - val_loss: 0.0544 - val_accuracy: 0.9830
Epoch 6/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0979 - accuracy: 0.9719 - val_loss: 0.0498 - val_accuracy: 0.9865
Epoch 7/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0881 - accuracy: 0.9744 - val_loss: 0.0514 - val_accuracy: 0.9858
Epoch 8/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0812 - accuracy: 0.9763 - val_loss: 0.0481 - val_accuracy: 0.9870
Epoch 9/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0766 - accuracy: 0.9781 - val_loss: 0.0474 - val_accuracy: 0.9873
Epoch 10/20
844/844 [==============================] - 13s 15ms/step - loss: 0.0726 - accuracy: 0.9786 - val_loss: 0.0447 - val_accuracy: 0.9885
Epoch 11/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0698 - accuracy: 0.9795 - val_loss: 0.0436 - val_accuracy: 0.9892
Epoch 12/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0649 - accuracy: 0.9807 - val_loss: 0.0447 - val_accuracy: 0.9888
Epoch 13/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0651 - accuracy: 0.9803 - val_loss: 0.0438 - val_accuracy: 0.9895
Epoch 14/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0622 - accuracy: 0.9816 - val_loss: 0.0445 - val_accuracy: 0.9890
Epoch 15/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0583 - accuracy: 0.9822 - val_loss: 0.0444 - val_accuracy: 0.9893
Epoch 16/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0572 - accuracy: 0.9832 - val_loss: 0.0418 - val_accuracy: 0.9893
Epoch 17/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0559 - accuracy: 0.9836 - val_loss: 0.0425 - val_accuracy: 0.9903
Epoch 18/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0536 - accuracy: 0.9836 - val_loss: 0.0432 - val_accuracy: 0.9905
Epoch 19/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0526 - accuracy: 0.9839 - val_loss: 0.0437 - val_accuracy: 0.9900
Epoch 20/20
844/844 [==============================] - 12s 15ms/step - loss: 0.0509 - accuracy: 0.9843 - val_loss: 0.0396 - val_accuracy: 0.9910
Test loss: 0.03273899853229523
Test accuracy: 0.9900000691413879

A: Surprisingly, it still works, unless you pick a high value for the parameter. The exponential layer only outputs positive values, but that is enough information for the output layer to do its job.

Variational autoencoder

We are now ready to implement the VAE. We are going to redefine the training set, as we want pixel values to be between 0 and 1 (so that we can compute a cross-entropy). Therefore, we do not perform removal:

# Fetch the MNIST data
(X_train, t_train), (X_test, t_test) = tf.keras.datasets.mnist.load_data()
print("Training data:", X_train.shape, t_train.shape)
print("Test data:", X_test.shape, t_test.shape)

# Normalize the values
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255.

# One-hot encoding
T_train = tf.keras.utils.to_categorical(t_train, 10)
T_test = tf.keras.utils.to_categorical(t_test, 10)
Training data: (60000, 28, 28) (60000,)
Test data: (10000, 28, 28) (10000,)

Encoder

The encoder can have any form, the only constraint is that is takes an input (28, 28, 1) and outputs two vectors \mu and \log(\sigma) of size latent_dim, the parameters of the normal distribution representing the input. We are going to use only latent_dim=2 latent dimensions, but let’s make the code generic.

For a network to have two outputs, one just needs to use the functional API to create the graph:

# Previous layer
x = tf.keras.layers.Dense(N, activation="relu")(x)

# First output takes input from x
z_mean = tf.keras.layers.Dense(latent_dim)(x)

# Second output also takes input from x  
z_log_var = tf.keras.layers.Dense(latent_dim)(x)

This would not be possible using the Sequential API, but is straightforward using the functional one, as you decide from where a layer takes its inputs.

What we still need and is not standard in keras is a sampling layer that implements the reparameterization trick:

\mathbf{z} = \mu + \sigma \, \xi

where \xi comes from the standard normal distribution \mathcal{N}(0, 1).

For technical reasons, it is actually better when z_log_var represents \log \sigma^2 instead of \sigma, as it can take both positive and negative values, while \sigma could only be strictly positive.

\text{z\_log\_var} = \log \, \sigma^2

\sigma = \exp \dfrac{\text{z\_log\_var}}{2}

We therefore want a layer that computes:

z = z_mean + tf.math.exp(0.5 * z_log_var) * xi

on the tensors of shape (batch_size, latent_dim). To sample the standard normal distribution, you can use tensorflow:

xi = tf.random.normal(shape=(batch_size, latent_dim) mean=0.0, stddev=1.0)

Q: Create a custom SamplingLayer layer that takes inputs from z_mean and z_log_var, being called like this:

z = SamplingLayer()([z_mean, z_log_var])

In order to get each input separately, the inputs argument can be split:

def call(self, inputs):
    z_mean, z_log_var = inputs

The only difficulty is to pass the correct dimensions to xi, as you do not know the batch size yet. You can retrieve it using the shape of z_mean:

batch_size = tf.shape(z_mean)[0]
latent_dim = tf.shape(z_mean)[1]
class SamplingLayer(tf.keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z."""

    def __init__(self):
        super(SamplingLayer, self).__init__()

    def call(self, inputs):
        # Retrieve inputs mu and 2*log(sigma)
        z_mean, z_log_var = inputs

        # Batch size and latent dimension
        batch_size = tf.shape(z_mean)[0]
        latent_dim = tf.shape(z_mean)[1]
        
        # Random variable from the standard normal distribution
        xi = tf.random.normal(shape=(batch_size, latent_dim), mean=0.0, stddev=1.0)
        
        # Reparameterization trick
        return z_mean + tf.math.exp(0.5 * z_log_var) * xi

We can now create the encoder in a create_encoder(latent_dim) method that return an uncompiled model.

You can put what you want in the encoder as long as it takes a (28, 28, 1) input and returns the three layers [z_mean, z_log_var, z] (we need z_mean and z_log_var to define the loss, normally you only need z):

def create_encoder(latent_dim):

    inputs = tf.keras.layers.Input(shape=(28, 28, 1))
    
    # Stuff, with x being the last FC layer

    z_mean = tf.keras.layers.Dense(latent_dim)(x)
    
    z_log_var = tf.keras.layers.Dense(latent_dim)(x)
    
    z = SamplingLayer()([z_mean, z_log_var])

    model = tf.keras.Model(inputs, [z_mean, z_log_var, z])
    
    print(model.summary())

    return model

One suggestion would be to use two convolutional layers with a stride of 2 (replacing max-pooling) and one fully-connected layer with enough neurons, but you do what you want.

Q: Create the encoder.

def create_encoder(latent_dim):

    inputs = tf.keras.layers.Input(shape=(28, 28, 1))
    
    x = tf.keras.layers.Conv2D(32, (3, 3), strides=2, activation='relu', padding='valid')(inputs)

    x = tf.keras.layers.Conv2D(64, (3, 3), strides=2, activation='relu', padding='valid')(x)

    x = tf.keras.layers.Flatten()(x)

    x = tf.keras.layers.Dense(16, activation="relu")(x)

    z_mean = tf.keras.layers.Dense(latent_dim)(x)
    
    z_log_var = tf.keras.layers.Dense(latent_dim)(x)
    
    z = SamplingLayer()([z_mean, z_log_var])

    model = tf.keras.Model(inputs, [z_mean, z_log_var, z])
    
    print(model.summary())

    return model

The decoder is a bit more tricky. It takes the vector z as an input (latent_dim=2 dimensions) and should output an image (28, 28, 1) with pixels normailzed between 0 and 1. The output layer should therefore use the 'sigmoid' transfer function:

def create_decoder(latent_dim):
    
    inputs = tf.keras.layers.Input(shape=(latent_dim,))

    # Stuff, with x being a transposed convolution layer of shape (28, 28, N)
    
    outputs = tf.keras.layers.Conv2DTranspose(1, (3, 3), activation="sigmoid", padding="same")(x)
    
    model = tf.keras.Model(inputs, outputs)
    print(model.summary())
    
    return model

The decoder has to use transposed convolutions to upsample the tensors instead of downsampling them. Check the doc of Conv2DTranspose at https://keras.io/api/layers/convolution_layers/convolution2d_transpose/.

In order to build the decoder, you have to be careful when it comes to tensor shapes: the output must be exactly (28, 28, 1), not (26, 26, 1), otherwise you will not be able to compute the reconstruction loss. You need to be careful with the stride (upsampling ratio) and padding method (‘same’ or ‘valid’) of the layers you add. Do not hesitate to create dummy models and print their summary to see the shapes.

Another trick is that you need to transform the vector z with latent_dim=2 elements into a 3D tensor before applying transposed convolutions (i.e. the inverse of Flatten()). If you for example want a tensor of shape (7, 7, 64) as the input to the first transposed convolution, you could project the vector to a fully connected layer with 7*7*64 neurons:

x = tf.keras.layers.Dense(7 * 7 * 64, activation="relu")(inputs)

and reshape it to a (7, 7, 64) tensor:

x = tf.keras.layers.Reshape((7, 7, 64))(x)

Q: Create the decoder.

def create_decoder(latent_dim):
    
    inputs = tf.keras.layers.Input(shape=(latent_dim,))

    x = tf.keras.layers.Dense(7 * 7 * 64, activation="relu")(inputs)
    x = tf.keras.layers.Reshape((7, 7, 64))(x)

    x = tf.keras.layers.Conv2DTranspose(64, (3, 3), activation="relu", strides=2, padding="same")(x)
    x = tf.keras.layers.Conv2DTranspose(32, (3, 3), activation="relu", strides=2, padding="same")(x)
    
    outputs = tf.keras.layers.Conv2DTranspose(1, (3, 3), activation="sigmoid", padding="same")(x)
    
    model = tf.keras.Model(inputs, outputs)
    print(model.summary())
    
    return model

Q: Create a custom VAE model (inheriting from tf.keras.Model) that:

  • takes the latent dimension as argument:
vae = VAE(latent_dim=2)
  • creates the encoder and decoder in the constructor.

  • tracks the reconstruction and KL losses as metrics.

  • does not use validation data (i.e., do not implement test_step() and do not provide any validation data to fit()).

  • computes the reconstruction loss using binary cross-entropy over all pixels of the reconstructed image:

\mathcal{L}_\text{reconstruction}(\theta) = \frac{1}{N} \sum_{i=1}^N \sum_{w, h \in \text{pixels}} - t^i(w, h) \, \log y^i(w, h) - (1 - t^i(w, h)) \, \log(1 - y^i(w, h))

where t^i(w, h) is the pixel of coordinates (w, h) (between 0 and 27) of the i-th image of the minibatch.

  • computes the KL divergence loss for the encoder:

\mathcal{L}_\text{KL}(\theta) = \frac{1}{2 N} \sum_{i=1}^N (\exp(\text{z\_log\_var}^i) + (\text{z\_mean}^i)^2 - 1 - \text{z\_log\_var}^i)

  • minimizes the total loss:

\mathcal{L}(\theta) = \mathcal{L}_\text{reconstruction}(\theta) + \mathcal{L}_\text{KL}(\theta)

Train it on the MNIST images for 30 epochs (or more) with the right batch size and a good optimizer (history = vae.fit(X_train, X_train, epochs=30, batch_size=b)). How do the losses evolve?

Hint: for the reconstruction loss, you can implement the formula using tensorflow operations, or call tf.keras.losses.binary_crossentropy(t, y) directly.

Do not worry if your reconstruction loss does not go to zero, but stays in the hundreds, it is normal. Use the next cell to visualize the reconstructions.

Note: The KL is expressed for a single sample as:

\mathcal{L}_\text{KL}(\theta) = \dfrac{1}{2} \, (\mathbf{\sigma_x^2} + \mathbf{\mu_x}^2 - 1 - \log \mathbf{\sigma_x^2})

With \text{z\_log\_var} = \log \, \sigma^2 or \sigma = \exp \dfrac{\text{z\_log\_var}}{2}, this becomes:

\mathcal{L}_\text{KL}(\theta) = \dfrac{1}{2} \, (\exp \text{z\_log\_var} + \mathbf{\mu_x}^2 - 1 - \text{z\_log\_var})

class VAE(tf.keras.Model):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = create_encoder(latent_dim)

        # Decoder
        self.decoder = create_decoder(latent_dim)
        
        # Track losses
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def train_step(self, data):

        with tf.GradientTape() as tape:

            # Data: input = output
            X, t = data

            # Encoder
            z_mean, z_log_var, z = self.encoder(X)
            
            # Decoder
            y = self.decoder(z)
            
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                   #- t * tf.math.log(y) - (1. - t) * tf.math.log(1. - y), 
                   tf.keras.losses.binary_crossentropy(t, y),
                    axis=(1, 2)
                )
            )

            kl_loss = tf.reduce_mean(
                tf.reduce_sum(
                    0.5 * (tf.exp(z_log_var) + tf.square(z_mean) - 1 - z_log_var ), 
                    axis=1
                )
            )
            
            total_loss = reconstruction_loss + kl_loss
        
        grads = tape.gradient(total_loss, self.trainable_weights)
        
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]
# Delete all previous models to free memory
tf.keras.backend.clear_session()

# Create the VAE with 2 latent variables
vae = VAE(latent_dim=2)

# Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

# Compile
vae.compile(optimizer=optimizer)

# Train the VAE
history = vae.fit(X_train, X_train, epochs=30, batch_size=128)
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 13, 13, 32)   320         ['input_1[0][0]']                
                                                                                                  
 conv2d_1 (Conv2D)              (None, 6, 6, 64)     18496       ['conv2d[0][0]']                 
                                                                                                  
 flatten (Flatten)              (None, 2304)         0           ['conv2d_1[0][0]']               
                                                                                                  
 dense (Dense)                  (None, 16)           36880       ['flatten[0][0]']                
                                                                                                  
 dense_1 (Dense)                (None, 10)           170         ['dense[0][0]']                  
                                                                                                  
 dense_2 (Dense)                (None, 10)           170         ['dense[0][0]']                  
                                                                                                  
 sampling_layer (SamplingLayer)  (None, 10)          0           ['dense_1[0][0]',                
                                                                  'dense_2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 56,036
Trainable params: 56,036
Non-trainable params: 0
__________________________________________________________________________________________________
None
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 10)]              0         
                                                                 
 dense_3 (Dense)             (None, 3136)              34496     
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 64)       36928     
 nspose)                                                         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 32)       18464     
 ranspose)                                                       
                                                                 
 conv2d_transpose_2 (Conv2DT  (None, 28, 28, 1)        289       
 ranspose)                                                       
                                                                 
=================================================================
Total params: 90,177
Trainable params: 90,177
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/30
2022-12-14 10:49:27.940751: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
469/469 [==============================] - 11s 21ms/step - loss: 389.6485 - reconstruction_loss: 283.3474 - kl_loss: 7.9774
Epoch 2/30
469/469 [==============================] - 9s 20ms/step - loss: 213.7090 - reconstruction_loss: 206.4947 - kl_loss: 4.7493
Epoch 3/30
469/469 [==============================] - 9s 20ms/step - loss: 205.4718 - reconstruction_loss: 199.8490 - kl_loss: 3.9704
Epoch 4/30
469/469 [==============================] - 9s 20ms/step - loss: 198.6427 - reconstruction_loss: 190.5134 - kl_loss: 5.2328
Epoch 5/30
469/469 [==============================] - 9s 20ms/step - loss: 181.9525 - reconstruction_loss: 167.7927 - kl_loss: 8.7311
Epoch 6/30
469/469 [==============================] - 9s 20ms/step - loss: 164.0484 - reconstruction_loss: 149.6027 - kl_loss: 10.5091
Epoch 7/30
469/469 [==============================] - 9s 20ms/step - loss: 153.4709 - reconstruction_loss: 141.1168 - kl_loss: 11.2486
Epoch 8/30
469/469 [==============================] - 9s 20ms/step - loss: 149.5024 - reconstruction_loss: 137.3778 - kl_loss: 11.6093
Epoch 9/30
469/469 [==============================] - 9s 20ms/step - loss: 146.2169 - reconstruction_loss: 132.1913 - kl_loss: 12.4981
Epoch 10/30
469/469 [==============================] - 9s 20ms/step - loss: 141.2454 - reconstruction_loss: 127.1392 - kl_loss: 13.5203
Epoch 11/30
469/469 [==============================] - 9s 20ms/step - loss: 137.9547 - reconstruction_loss: 122.6539 - kl_loss: 14.6289
Epoch 12/30
469/469 [==============================] - 9s 20ms/step - loss: 135.0639 - reconstruction_loss: 119.6763 - kl_loss: 15.1527
Epoch 13/30
469/469 [==============================] - 9s 20ms/step - loss: 133.2559 - reconstruction_loss: 117.9188 - kl_loss: 15.2583
Epoch 14/30
469/469 [==============================] - 9s 20ms/step - loss: 132.2416 - reconstruction_loss: 116.6409 - kl_loss: 15.3469
Epoch 15/30
469/469 [==============================] - 9s 20ms/step - loss: 131.1907 - reconstruction_loss: 115.5849 - kl_loss: 15.3408
Epoch 16/30
469/469 [==============================] - 9s 20ms/step - loss: 130.0235 - reconstruction_loss: 114.6940 - kl_loss: 15.3268
Epoch 17/30
469/469 [==============================] - 9s 20ms/step - loss: 129.3353 - reconstruction_loss: 113.9605 - kl_loss: 15.3060
Epoch 18/30
469/469 [==============================] - 9s 20ms/step - loss: 128.8334 - reconstruction_loss: 113.2801 - kl_loss: 15.2965
Epoch 19/30
469/469 [==============================] - 9s 20ms/step - loss: 128.3190 - reconstruction_loss: 112.7102 - kl_loss: 15.2659
Epoch 20/30
469/469 [==============================] - 9s 20ms/step - loss: 127.6685 - reconstruction_loss: 112.1943 - kl_loss: 15.2335
Epoch 21/30
469/469 [==============================] - 9s 20ms/step - loss: 127.2754 - reconstruction_loss: 111.7160 - kl_loss: 15.2299
Epoch 22/30
469/469 [==============================] - 9s 20ms/step - loss: 126.3912 - reconstruction_loss: 111.2648 - kl_loss: 15.1921
Epoch 23/30
469/469 [==============================] - 9s 20ms/step - loss: 126.1763 - reconstruction_loss: 110.8554 - kl_loss: 15.1548
Epoch 24/30
469/469 [==============================] - 9s 20ms/step - loss: 125.7592 - reconstruction_loss: 110.5094 - kl_loss: 15.1186
Epoch 25/30
469/469 [==============================] - 9s 20ms/step - loss: 125.3898 - reconstruction_loss: 110.1903 - kl_loss: 15.1253
Epoch 26/30
469/469 [==============================] - 185s 394ms/step - loss: 124.7847 - reconstruction_loss: 109.8150 - kl_loss: 15.0696
Epoch 27/30
469/469 [==============================] - 9s 20ms/step - loss: 124.7143 - reconstruction_loss: 109.5486 - kl_loss: 15.0504
Epoch 28/30
469/469 [==============================] - 9s 20ms/step - loss: 124.3730 - reconstruction_loss: 109.2505 - kl_loss: 15.0166
Epoch 29/30
469/469 [==============================] - 10s 20ms/step - loss: 123.6702 - reconstruction_loss: 108.9420 - kl_loss: 14.9751
Epoch 30/30
469/469 [==============================] - 9s 20ms/step - loss: 123.7484 - reconstruction_loss: 108.7186 - kl_loss: 14.9486
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], '-r', label="Total loss")
plt.plot(history.history['reconstruction_loss'], '-b', label="Reconstruction loss")
plt.plot(history.history['kl_loss'], '-g', label="KL loss")
plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.legend()
plt.show()

Q: The following cell allows to regularly sample the latent space and reconstruct the images. It makes the assumption that the decoder is stored at vae.decoder, adapt it otherwise. Comment on the generated samples. Observe in particular the smooth transitions between similar digits.

def plot_latent_space(vae, n=30, figsize=15):
    # display a n*n 2D manifold of digits
    digit_size = 28
    scale = 2.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = vae.decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(vae)
2022-12-13 16:28:51.901012: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.

Q: The following cell visualizes the latent representation for the training data, using different colors for the digits. What do you think?

def plot_label_clusters(vae, data, labels):
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


plot_label_clusters(vae, X_train, t_train)
2022-12-13 16:29:08.762616: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.

A: Without having been instructed to, the encoder already separates quite well the different classes of digits in the latent space. A shallow classifier on the latent space with 2 dimensions might be able to classify the MNIST data!

Here, we used labelled data to train the autoencoder, but the use-case would be semi-supervised learning: train the autoencoder on unsupervised unlabelled data, and then train a classifier on its latent space using a small amount of labelled data.