{"id":2370,"date":"2019-07-16T19:00:24","date_gmt":"2019-07-16T19:00:24","guid":{"rendered":"https:\/\/www.aiproblog.com\/index.php\/2019\/07\/16\/how-to-develop-a-wasserstein-generative-adversarial-network-wgan-from-scratch\/"},"modified":"2019-07-16T19:00:24","modified_gmt":"2019-07-16T19:00:24","slug":"how-to-develop-a-wasserstein-generative-adversarial-network-wgan-from-scratch","status":"publish","type":"post","link":"https:\/\/www.aiproblog.com\/index.php\/2019\/07\/16\/how-to-develop-a-wasserstein-generative-adversarial-network-wgan-from-scratch\/","title":{"rendered":"How to Develop a Wasserstein Generative Adversarial Network (WGAN) From Scratch"},"content":{"rendered":"<p>Author: Jason Brownlee<\/p>\n<div>\n<p>The Wasserstein Generative Adversarial Network, or Wasserstein GAN, is an extension to the generative adversarial network that both improves the stability when training the model and provides a loss function that correlates with the quality of generated images.<\/p>\n<p>The development of the WGAN has a dense mathematical motivation, although in practice requires only a few minor modifications to the established standard deep convolutional generative adversarial network, or DCGAN.<\/p>\n<p>In this tutorial, you will discover how to implement the Wasserstein generative adversarial network from scratch.<\/p>\n<p>After completing this tutorial, you will know:<\/p>\n<ul>\n<li>The differences between the standard deep convolutional GAN and the new Wasserstein GAN.<\/li>\n<li>How to implement the specific details of the Wasserstein GAN from scratch.<\/li>\n<li>How to develop a WGAN for image generation and interpret the dynamic behavior of the model.<\/li>\n<\/ul>\n<p>Discover how to develop DCGANs, conditional GANs, Pix2Pix, CycleGANs, and more with Keras <a href=\"https:\/\/machinelearningmastery.com\/generative_adversarial_networks\/\" rel=\"nofollow\">in my new GANs book<\/a>, with 29 step-by-step tutorials and full source code.<\/p>\n<p>Let\u2019s 
get started.<\/p>\n<div id=\"attachment_8226\" style=\"width: 650px\" class=\"wp-caption aligncenter\"><img loading=\"lazy\" decoding=\"async\" aria-describedby=\"caption-attachment-8226\" class=\"size-full wp-image-8226\" src=\"https:\/\/machinelearningmastery.com\/wp-content\/uploads\/2019\/07\/How-to-Code-a-Wasserstein-Generative-Adversarial-Network-WGAN-From-Scratch.jpg\" alt=\"How to Code a Wasserstein Generative Adversarial Network (WGAN) From Scratch\" width=\"640\" height=\"426\" srcset=\"http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/07\/How-to-Code-a-Wasserstein-Generative-Adversarial-Network-WGAN-From-Scratch.jpg 640w, http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/07\/How-to-Code-a-Wasserstein-Generative-Adversarial-Network-WGAN-From-Scratch-300x200.jpg 300w\" sizes=\"(max-width: 640px) 100vw, 640px\"><\/p>\n<p id=\"caption-attachment-8226\" class=\"wp-caption-text\">How to Code a Wasserstein Generative Adversarial Network (WGAN) From Scratch<br \/>Photo by <a href=\"https:\/\/www.flickr.com\/photos\/40145521@N00\/2316564611\">Feliciano Guimar\u00e3es<\/a>, some rights reserved.<\/p>\n<\/div>\n<h2>Tutorial Overview<\/h2>\n<p>This tutorial is divided into three parts; they are:<\/p>\n<ol>\n<li>Wasserstein Generative Adversarial Network<\/li>\n<li>Wasserstein GAN Implementation Details<\/li>\n<li>How to Train a Wasserstein GAN Model<\/li>\n<\/ol>\n<h2>Wasserstein Generative Adversarial Network<\/h2>\n<p>The Wasserstein GAN, or WGAN for short, was introduced by Martin Arjovsky, et al. 
in their 2017 paper titled \u201c<a href=\"https:\/\/arxiv.org\/abs\/1701.07875\">Wasserstein GAN<\/a>.\u201d<\/p>\n<p>It is an extension of the GAN that seeks an alternate way of training the generator model to better approximate the distribution of data observed in a given training dataset.<\/p>\n<p>Instead of using a discriminator to classify or predict the probability of generated images as being real or fake, the WGAN changes or replaces the discriminator model with a critic that scores the realness or fakeness of a given image.<\/p>\n<p>This change is motivated by a theoretical argument that training the generator should seek a minimization of the distance between the distribution of the data observed in the training dataset and the distribution observed in generated examples.<\/p>\n<p>The benefit of the WGAN is that the training process is more stable and less sensitive to model architecture and choice of hyperparameter configurations. Perhaps most importantly, the loss of the discriminator appears to relate to the quality of images created by the generator.<\/p>\n<h2>Wasserstein GAN Implementation Details<\/h2>\n<p>Although the theoretical grounding for the WGAN is dense, the implementation of a WGAN requires a few minor changes to the standard Deep Convolutional GAN, or DCGAN.<\/p>\n<p>The image below provides a summary of the main training loop for training a WGAN, taken from the paper. 
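As a brief aside, the distance between distributions referred to above is the Wasserstein-1 or earth mover's distance. For one-dimensional samples it can be computed exactly, which makes the idea easy to check numerically; the sketch below uses a small helper (the emd_1d() function is an illustration added here, not part of the WGAN code developed in this tutorial):

```python
# 1-D earth mover's distance between two equal-sized samples:
# sort both samples and average the pointwise gaps. This is the
# quantity (between real and generated data) that WGAN training
# seeks to minimize, shown here for toy data.
def emd_1d(a, b):
	assert len(a) == len(b)
	return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

real = [0.0, 1.0, 3.0]   # toy 'real' samples
fake = [5.0, 6.0, 8.0]   # toy 'generated' samples
print(emd_1d(real, fake))  # 5.0
```

As the generated samples move toward the real ones, this distance shrinks smoothly toward zero, which is the property that makes the WGAN loss informative during training.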
Note the listing of recommended hyperparameters used in the model.<\/p>\n<div id=\"attachment_8222\" style=\"width: 650px\" class=\"wp-caption aligncenter\"><img loading=\"lazy\" decoding=\"async\" aria-describedby=\"caption-attachment-8222\" class=\"size-full wp-image-8222\" src=\"https:\/\/machinelearningmastery.com\/wp-content\/uploads\/2019\/05\/Algorithm-for-the-Wasserstein-Generative-Adversarial-Networks-1.png\" alt=\"Algorithm for the Wasserstein Generative Adversarial Networks\" width=\"640\" height=\"372\" srcset=\"http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Algorithm-for-the-Wasserstein-Generative-Adversarial-Networks-1.png 640w, http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Algorithm-for-the-Wasserstein-Generative-Adversarial-Networks-1-300x174.png 300w\" sizes=\"(max-width: 640px) 100vw, 640px\"><\/p>\n<p id=\"caption-attachment-8222\" class=\"wp-caption-text\">Algorithm for the Wasserstein Generative Adversarial Networks.<br \/>Taken from: Wasserstein GAN.<\/p>\n<\/div>\n<p>The differences in implementation for the WGAN are as follows:<\/p>\n<ol>\n<li>Use a linear activation function in the output layer of the critic model (instead of sigmoid).<\/li>\n<li>Use -1 labels for real images and 1 labels for fake images (instead of 1 and 0).<\/li>\n<li>Use Wasserstein loss to train the critic and generator models.<\/li>\n<li>Constrain critic model weights to a limited range after each mini batch update (e.g. [-0.01,0.01]).<\/li>\n<li>Update the critic model more times than the generator each iteration (e.g. 5).<\/li>\n<li>Use the RMSProp version of gradient descent with a small learning rate and no momentum (e.g. 
0.00005).<\/li>\n<\/ol>\n<p>Using the standard DCGAN model as a starting point, let\u2019s take a look at each of these implementation details in turn.<\/p>\n<h3>1. 
Linear Activation in Critic Output Layer<\/h3>\n<p>The DCGAN uses the sigmoid activation function in the output layer of the discriminator to predict the likelihood of a given image being real.<\/p>\n<p>In the WGAN, the critic model requires a linear activation to predict the score of \u201c<em>realness<\/em>\u201d for a given image.<\/p>\n<p>This can be achieved by setting the \u2018<em>activation<\/em>\u2018 argument to \u2018<em>linear<\/em>\u2018 in the output layer of the critic model.<\/p>\n<pre class=\"crayon-plain-tag\"># define output layer of the critic model\r\n...\r\nmodel.add(Dense(1, activation='linear'))<\/pre>\n<p>The linear activation is the default activation for a layer, so we can, in fact, leave the activation unspecified to achieve the same result.<\/p>\n<pre class=\"crayon-plain-tag\"># define output layer of the critic model\r\n...\r\nmodel.add(Dense(1))<\/pre>\n<\/p>\n<h3>2. Class Labels for Real and Fake Images<\/h3>\n<p>The DCGAN uses the class 0 for fake images and class 1 for real images, and these class labels are used to train the GAN.<\/p>\n<p>In the DCGAN, these are precise labels that the discriminator is expected to achieve. The WGAN does not have precise labels for the critic. Instead, it encourages the critic to output scores that are different for real and fake images.<\/p>\n<p>This is achieved via the Wasserstein function that cleverly makes use of positive and negative class labels.<\/p>\n<p>The WGAN can be implemented where -1 class labels are used for real images and +1 class labels are used for fake or generated images.<\/p>\n<p>This can be achieved using the <a href=\"https:\/\/docs.scipy.org\/doc\/numpy\/reference\/generated\/numpy.ones.html\">ones() NumPy function<\/a>.<\/p>\n<p>For example:<\/p>\n<pre class=\"crayon-plain-tag\">...\r\n# generate class labels, -1 for 'real'\r\ny = -ones((n_samples, 1))\r\n...\r\n# create class labels with 1.0 for 'fake'\r\ny = ones((n_samples, 1))<\/pre>\n<\/p>\n<h3>3. 
Wasserstein Loss Function<\/h3>\n<p>The DCGAN trains the discriminator as a binary classification model to predict the probability that a given image is real.<\/p>\n<p>To train this model, the discriminator is optimized using the binary cross entropy loss function. The same loss function is used to update the generator model.<\/p>\n<p>The primary contribution of the WGAN model is the use of a new loss function that encourages the discriminator to predict a score of how real or fake a given input looks. This transforms the role of the discriminator from a classifier into a critic for scoring the realness or fakeness of images, where the difference between the scores is as large as possible.<\/p>\n<p>We can implement the Wasserstein loss as a custom function in Keras that calculates the average score for real or fake images.<\/p>\n<p>The score is maximized for real examples and minimized for fake examples. Given that stochastic gradient descent is a minimization algorithm, we can multiply the mean score by the class label (-1 for real, which flips the sign of the score, and 1 for fake, which has no effect), ensuring that the loss for both real and fake images is a quantity the network can minimize. For example, a mean score of 2.5 for real images with the label -1 gives a loss of -2.5, so the more strongly the critic scores real images as real, the lower the loss.<\/p>\n<p>An efficient implementation of this loss function for Keras is listed below.<\/p>\n<pre class=\"crayon-plain-tag\">from keras import backend\r\n\r\n# implementation of wasserstein loss\r\ndef wasserstein_loss(y_true, y_pred):\r\n\treturn backend.mean(y_true * y_pred)<\/pre>\n<p>This loss function can be used to train a Keras model by specifying the function name when compiling the model.<\/p>\n<p>For example:<\/p>\n<pre class=\"crayon-plain-tag\">...\r\n# compile the model\r\nmodel.compile(loss=wasserstein_loss, ...)<\/pre>\n<\/p>\n<h3>4. 
Critic Weight Clipping<\/h3>\n<p>The DCGAN does not use any weight clipping, although the WGAN requires weight clipping for the critic model after each mini-batch update.<\/p>\n<p>We can implement weight clipping as a Keras constraint.<\/p>\n<p>This is a class that must extend the <em>Constraint<\/em> class and define an implementation of the <em>__call__()<\/em> function for applying the operation and the <em>get_config()<\/em> function for returning any configuration.<\/p>\n<p>We can also define an <em>__init__()<\/em> function to set the configuration, in this case, the symmetrical size of the bounding box for the weight hypercube, e.g. 0.01.<\/p>\n<p>The <em>ClipConstraint<\/em> class is defined below.<\/p>\n<pre class=\"crayon-plain-tag\"># clip model weights to a given hypercube\r\nclass ClipConstraint(Constraint):\r\n\t# set clip value when initialized\r\n\tdef __init__(self, clip_value):\r\n\t\tself.clip_value = clip_value\r\n\r\n\t# clip model weights to hypercube\r\n\tdef __call__(self, weights):\r\n\t\treturn backend.clip(weights, -self.clip_value, self.clip_value)\r\n\r\n\t# get the config\r\n\tdef get_config(self):\r\n\t\treturn {'clip_value': self.clip_value}<\/pre>\n<p>To use the constraint, the class can be constructed, then used in a layer by setting the <em>kernel_constraint<\/em> argument; for example:<\/p>\n<pre class=\"crayon-plain-tag\">...\r\n# define the constraint\r\nconst = ClipConstraint(0.01)\r\n...\r\n# use the constraint in a layer\r\nmodel.add(Conv2D(..., kernel_constraint=const))<\/pre>\n<p>The constraint is only required when updating the critic model.<\/p>\n<h3>5. 
Update Critic More Than Generator<\/h3>\n<p>In the DCGAN, the generator and the discriminator model must be updated in equal amounts.<\/p>\n<p>Specifically, the discriminator is updated with a half batch of real and a half batch of fake samples each iteration, whereas the generator is updated with a single batch of generated samples.<\/p>\n<p>For example:<\/p>\n<pre class=\"crayon-plain-tag\">...\r\n# main gan training loop\r\nfor i in range(n_steps):\r\n\r\n\t# update the discriminator\r\n\r\n\t# get randomly selected 'real' samples\r\n\tX_real, y_real = generate_real_samples(dataset, half_batch)\r\n\t# update critic model weights\r\n\tc_loss1 = c_model.train_on_batch(X_real, y_real)\r\n\t# generate 'fake' examples\r\n\tX_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)\r\n\t# update critic model weights\r\n\tc_loss2 = c_model.train_on_batch(X_fake, y_fake)\r\n\r\n\t# update generator\r\n\r\n\t# prepare points in latent space as input for the generator\r\n\tX_gan = generate_latent_points(latent_dim, n_batch)\r\n\t# create inverted labels for the fake samples\r\n\ty_gan = ones((n_batch, 1))\r\n\t# update the generator via the critic's error\r\n\tg_loss = gan_model.train_on_batch(X_gan, y_gan)<\/pre>\n<p>In the WGAN model, the critic model must be updated more than the generator model.<\/p>\n<p>Specifically, a new hyperparameter is defined to control the number of times that the critic is updated for each update to the generator model, called n_critic, and is set to 5.<\/p>\n<p>This can be implemented as a new loop within the main GAN update loop; for example:<\/p>\n<pre class=\"crayon-plain-tag\">...\r\n# main gan training loop\r\nfor i in range(n_steps):\r\n\r\n\t# update the critic\r\n\tfor _ in range(n_critic):\r\n\t\t# get randomly selected 'real' samples\r\n\t\tX_real, y_real = generate_real_samples(dataset, half_batch)\r\n\t\t# update critic model weights\r\n\t\tc_loss1 = c_model.train_on_batch(X_real, y_real)\r\n\t\t# generate 'fake' 
examples\r\n\t\tX_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)\r\n\t\t# update critic model weights\r\n\t\tc_loss2 = c_model.train_on_batch(X_fake, y_fake)\r\n\r\n\t# update generator\r\n\r\n\t# prepare points in latent space as input for the generator\r\n\tX_gan = generate_latent_points(latent_dim, n_batch)\r\n\t# create inverted labels (-1, for 'real') for the fake samples\r\n\ty_gan = -ones((n_batch, 1))\r\n\t# update the generator via the critic's error\r\n\tg_loss = gan_model.train_on_batch(X_gan, y_gan)<\/pre>\n<\/p>\n<h3>6. Use RMSProp Stochastic Gradient Descent<\/h3>\n<p>The DCGAN uses the <a href=\"https:\/\/machinelearningmastery.com\/adam-optimization-algorithm-for-deep-learning\/\">Adam version of stochastic gradient descent<\/a> with a small learning rate and modest momentum.<\/p>\n<p>The WGAN recommends the use of <a href=\"https:\/\/machinelearningmastery.com\/understand-the-dynamics-of-learning-rate-on-deep-learning-neural-networks\/\">RMSProp instead<\/a>, with a small learning rate of 0.00005.<\/p>\n<p>This can be implemented in Keras when the model is compiled. For example:<\/p>\n<pre class=\"crayon-plain-tag\">...\r\n# compile model\r\nopt = RMSprop(lr=0.00005)\r\nmodel.compile(loss=wasserstein_loss, optimizer=opt)<\/pre>\n<\/p>\n<h2>How to Train a Wasserstein GAN Model<\/h2>\n<p>Now that we know the specific implementation details for the WGAN, we can implement the model for image generation.<\/p>\n<p>In this section, we will develop a WGAN to generate a single handwritten digit (\u20187\u2019) from the <a href=\"https:\/\/machinelearningmastery.com\/how-to-develop-a-cnn-from-scratch-for-fashion-mnist-clothing-classification\/\">MNIST dataset<\/a>. 
This is a good test problem for the WGAN as it is a small dataset requiring a modest model that is quick to train.<\/p>\n<p>The first step is to define the models.<\/p>\n<p>The critic model takes as input one 28\u00d728 grayscale image and outputs a score for the realness or fakeness of the image. It is implemented as a modest convolutional neural network using best practices for DCGAN design such as using the <a href=\"https:\/\/machinelearningmastery.com\/rectified-linear-activation-function-for-deep-learning-neural-networks\/\">LeakyReLU activation function<\/a> with a slope of 0.2, <a href=\"https:\/\/machinelearningmastery.com\/how-to-accelerate-learning-of-deep-neural-networks-with-batch-normalization\/\">batch normalization<\/a>, and using a <a href=\"https:\/\/machinelearningmastery.com\/padding-and-stride-for-convolutional-neural-networks\/\">2\u00d72 stride to downsample<\/a>.<\/p>\n<p>The critic model makes use of the new ClipConstraint weight constraint to clip model weights after mini-batch updates and is optimized using the custom <em>wasserstein_loss()<\/em> function and the RMSProp version of stochastic gradient descent with a learning rate of 0.00005.<\/p>\n<p>The <em>define_critic()<\/em> function below implements this, defining and compiling the critic model and returning it. 
The input shape of the image is parameterized as a default function argument to make it clear.<\/p>\n<pre class=\"crayon-plain-tag\"># define the standalone critic model\r\ndef define_critic(in_shape=(28,28,1)):\r\n\t# weight initialization\r\n\tinit = RandomNormal(stddev=0.02)\r\n\t# weight constraint\r\n\tconst = ClipConstraint(0.01)\r\n\t# define model\r\n\tmodel = Sequential()\r\n\t# downsample to 14x14\r\n\tmodel.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# downsample to 7x7\r\n\tmodel.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# scoring, linear activation\r\n\tmodel.add(Flatten())\r\n\tmodel.add(Dense(1))\r\n\t# compile model\r\n\topt = RMSprop(lr=0.00005)\r\n\tmodel.compile(loss=wasserstein_loss, optimizer=opt)\r\n\treturn model<\/pre>\n<p>The generator model takes as input a point in the latent space and outputs a single 28\u00d728 grayscale image.<\/p>\n<p>This is achieved by using a fully connected layer to interpret the point in the latent space and provide sufficient activations that can be reshaped into many copies (in this case, 128) of a low-resolution version of the output image (e.g. 7\u00d77). This is then upsampled two times, doubling the size and quadrupling the area of the activations each time using transpose convolutional layers.<\/p>\n<p>The model uses best practices such as the LeakyReLU activation, a kernel size that is a factor of the stride size, and a hyperbolic tangent (tanh) activation function in the output layer.<\/p>\n<p>The <em>define_generator()<\/em> function below defines the generator model but intentionally does not compile it as it is not trained directly, then returns the model. 
The size of the latent space is parameterized as a function argument.<\/p>\n<pre class=\"crayon-plain-tag\"># define the standalone generator model\r\ndef define_generator(latent_dim):\r\n\t# weight initialization\r\n\tinit = RandomNormal(stddev=0.02)\r\n\t# define model\r\n\tmodel = Sequential()\r\n\t# foundation for 7x7 image\r\n\tn_nodes = 128 * 7 * 7\r\n\tmodel.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\tmodel.add(Reshape((7, 7, 128)))\r\n\t# upsample to 14x14\r\n\tmodel.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# upsample to 28x28\r\n\tmodel.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# output 28x28x1\r\n\tmodel.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))\r\n\treturn model<\/pre>\n<p>Next, a GAN model can be defined that combines both the generator model and the critic model into one larger model.<\/p>\n<p>This larger model will be used to train the model weights in the generator, using the output and error calculated by the critic model. The critic model is trained separately, and as such, the model weights are marked as not trainable in this larger GAN model to ensure that only the weights of the generator model are updated. This change to the trainability of the critic weights only has an effect when training the combined GAN model, not when training the critic standalone.<\/p>\n<p>This larger GAN model takes as input a point in the latent space, uses the generator model to generate an image, which is fed as input to the critic model, where it is scored as real or fake. 
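The effect of the custom loss on this combined model can be sanity-checked with a quick numeric sketch; the function below mimics the wasserstein_loss() calculation in plain Python (standing in for the Keras backend.mean() call):

```python
# plain-python stand-in for the Keras wasserstein_loss():
# the mean of y_true * y_pred
def wasserstein_loss(y_true, y_pred):
	return sum(t * p for t, p in zip(y_true, y_pred)) / len(y_true)

# critic update on real images (label -1): higher scores -> lower loss
print(wasserstein_loss([-1, -1], [2.0, 3.0]))   # -2.5
# the generator update uses the 'real' label -1 for generated images,
# pushing the generator to raise the critic's score for its output
print(wasserstein_loss([-1, -1], [0.5, 1.5]))   # -1.0
```

In both cases, driving the critic's score for the labeled-real images higher makes the loss more negative, which is exactly the direction gradient descent will move.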
The model is fit using RMSProp with the custom <em>wasserstein_loss()<\/em> function.<\/p>\n<p>The <em>define_gan()<\/em> function below implements this, taking the already defined generator and critic models as input.<\/p>\n<pre class=\"crayon-plain-tag\"># define the combined generator and critic model, for updating the generator\r\ndef define_gan(generator, critic):\r\n\t# make weights in the critic not trainable\r\n\tcritic.trainable = False\r\n\t# connect them\r\n\tmodel = Sequential()\r\n\t# add generator\r\n\tmodel.add(generator)\r\n\t# add the critic\r\n\tmodel.add(critic)\r\n\t# compile model\r\n\topt = RMSprop(lr=0.00005)\r\n\tmodel.compile(loss=wasserstein_loss, optimizer=opt)\r\n\treturn model<\/pre>\n<p>Now that we have defined the GAN model, we need to train it. But, before we can train the model, we require input data.<\/p>\n<p>The first step is to load and <a href=\"https:\/\/machinelearningmastery.com\/how-to-manually-scale-image-pixel-data-for-deep-learning\/\">scale the MNIST dataset<\/a>. The whole dataset is loaded via a call to the <em>load_data()<\/em> Keras function, then a subset of the images is selected (about 5,000) that belongs to class 7, e.g. are a handwritten depiction of the number seven. Then the pixel values must be scaled to the range [-1,1] to match the output of the generator model.<\/p>\n<p>The <em>load_real_samples()<\/em> function below implements this, returning the loaded and scaled subset of the MNIST training dataset ready for modeling.<\/p>\n<pre class=\"crayon-plain-tag\"># load images\r\ndef load_real_samples():\r\n\t# load dataset\r\n\t(trainX, trainy), (_, _) = load_data()\r\n\t# select all of the examples for a given class\r\n\tselected_ix = trainy == 7\r\n\tX = trainX[selected_ix]\r\n\t# expand to 3d, e.g. 
add channels\r\n\tX = expand_dims(X, axis=-1)\r\n\t# convert from ints to floats\r\n\tX = X.astype('float32')\r\n\t# scale from [0,255] to [-1,1]\r\n\tX = (X - 127.5) \/ 127.5\r\n\treturn X<\/pre>\n<p>We will require one batch (or a half batch) of real images from the dataset for each update to the GAN model. A simple way to achieve this is to select a <a href=\"https:\/\/machinelearningmastery.com\/how-to-generate-random-numbers-in-python\/\">random sample<\/a> of images from the dataset each time.<\/p>\n<p>The <em>generate_real_samples()<\/em> function below implements this, taking the prepared dataset as an argument, selecting and returning a random sample of images and their corresponding label for the critic, specifically target=-1 indicating that they are real images.<\/p>\n<pre class=\"crayon-plain-tag\"># select real samples\r\ndef generate_real_samples(dataset, n_samples):\r\n\t# choose random instances\r\n\tix = randint(0, dataset.shape[0], n_samples)\r\n\t# select images\r\n\tX = dataset[ix]\r\n\t# generate class labels, -1 for 'real'\r\n\ty = -ones((n_samples, 1))\r\n\treturn X, y<\/pre>\n<p>Next, we need inputs for the generator model. 
These are random points from the latent space, specifically <a href=\"https:\/\/machinelearningmastery.com\/how-to-generate-random-numbers-in-python\/\">Gaussian distributed random variables<\/a>.<\/p>\n<p>The <em>generate_latent_points()<\/em> function implements this, taking the size of the latent space as an argument and the number of points required, and returning them as a batch of input samples for the generator model.<\/p>\n<pre class=\"crayon-plain-tag\"># generate points in latent space as input for the generator\r\ndef generate_latent_points(latent_dim, n_samples):\r\n\t# generate points in the latent space\r\n\tx_input = randn(latent_dim * n_samples)\r\n\t# reshape into a batch of inputs for the network\r\n\tx_input = x_input.reshape(n_samples, latent_dim)\r\n\treturn x_input<\/pre>\n<p>Next, we need to use the points in the latent space as input to the generator in order to generate new images.<\/p>\n<p>The <em>generate_fake_samples()<\/em> function below implements this, taking the generator model and size of the latent space as arguments, then generating points in the latent space and using them as input to the generator model.<\/p>\n<p>The function returns the generated images and their corresponding label for the critic model, specifically target=1 to indicate they are fake or generated.<\/p>\n<pre class=\"crayon-plain-tag\"># use the generator to generate n fake examples, with class labels\r\ndef generate_fake_samples(generator, latent_dim, n_samples):\r\n\t# generate points in latent space\r\n\tx_input = generate_latent_points(latent_dim, n_samples)\r\n\t# predict outputs\r\n\tX = generator.predict(x_input)\r\n\t# create class labels with 1.0 for 'fake'\r\n\ty = ones((n_samples, 1))\r\n\treturn X, y<\/pre>\n<p>We need to record the performance of the model. 
Perhaps the most reliable way to evaluate the performance of a GAN is to use the generator to generate images, and then review and subjectively evaluate them.<\/p>\n<p>The <em>summarize_performance()<\/em> function below takes the generator model at a given point during training and uses it to generate 100 images in a 10\u00d710 grid, that are then plotted and saved to file. The model is also saved to file at this time, in case we would like to use it later to generate more images.<\/p>\n<pre class=\"crayon-plain-tag\"># generate samples and save as a plot and save the model\r\ndef summarize_performance(step, g_model, latent_dim, n_samples=100):\r\n\t# prepare fake examples\r\n\tX, _ = generate_fake_samples(g_model, latent_dim, n_samples)\r\n\t# scale from [-1,1] to [0,1]\r\n\tX = (X + 1) \/ 2.0\r\n\t# plot images\r\n\tfor i in range(10 * 10):\r\n\t\t# define subplot\r\n\t\tpyplot.subplot(10, 10, 1 + i)\r\n\t\t# turn off axis\r\n\t\tpyplot.axis('off')\r\n\t\t# plot raw pixel data\r\n\t\tpyplot.imshow(X[i, :, :, 0], cmap='gray_r')\r\n\t# save plot to file\r\n\tfilename1 = 'generated_plot_%04d.png' % (step+1)\r\n\tpyplot.savefig(filename1)\r\n\tpyplot.close()\r\n\t# save the generator model\r\n\tfilename2 = 'model_%04d.h5' % (step+1)\r\n\tg_model.save(filename2)\r\n\tprint('>Saved: %s and %s' % (filename1, filename2))<\/pre>\n<p>In addition to image quality, it is a good idea to keep track of the loss and accuracy of the model over time.<\/p>\n<p>The loss for the critic for real and fake samples can be tracked for each model update, as can the loss for the generator for each update. These can then be used to create line plots of loss at the end of the training run. 
The <em>plot_history()<\/em> function below implements this and saves the results to file.<\/p>\n<pre class=\"crayon-plain-tag\"># create a line plot of loss for the gan and save to file\r\ndef plot_history(d1_hist, d2_hist, g_hist):\r\n\t# plot history\r\n\tpyplot.plot(d1_hist, label='crit_real')\r\n\tpyplot.plot(d2_hist, label='crit_fake')\r\n\tpyplot.plot(g_hist, label='gen')\r\n\tpyplot.legend()\r\n\tpyplot.savefig('plot_line_plot_loss.png')\r\n\tpyplot.close()<\/pre>\n<p>We are now ready to fit the GAN model.<\/p>\n<p>The model is fit for 10 training epochs, which is arbitrary, as the model begins generating plausible number-7 digits after perhaps the first few epochs. A batch size of 64 samples is used, and each training epoch involves 6,265\/64, or about 97, batches of real and fake samples and updates to the model. The model is therefore trained for 10 epochs of 97 batches, or 970 iterations.<\/p>\n<p>First, the critic model is updated for a half batch of real samples, then a half batch of fake samples, together forming one batch of weight updates. This is then repeated <em>n_critic<\/em> (5) times as required by the WGAN algorithm.<\/p>\n<p>The generator is then updated via the composite GAN model. Importantly, the target label is set to -1 or real for the generated samples. This has the effect of updating the generator toward getting better at generating real samples on the next batch.<\/p>\n<p>The <em>train()<\/em> function below implements this, taking the defined models, dataset, and size of the latent dimension as arguments and parameterizing the number of epochs and batch size with default arguments. The generator model is saved at the end of training.<\/p>\n<p>The performance of the critic and generator models is reported each iteration. 
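The training schedule described above can be confirmed with some quick arithmetic, using the numbers from the description (6,265 images of the digit seven, a batch size of 64, and 10 epochs):

```python
# training-schedule arithmetic for the WGAN run described above
n_samples = 6265   # MNIST training images of the digit 7
n_batch = 64       # batch size
n_epochs = 10      # training epochs
bat_per_epo = n_samples // n_batch   # batches per epoch
n_steps = bat_per_epo * n_epochs     # total training iterations
print(bat_per_epo, n_steps)  # 97 970
```

This matches the calculation performed at the top of the train() function below.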
Sample images are generated and saved every epoch, and line plots of model performance are created and saved at the end of the run.<\/p>\n<pre class=\"crayon-plain-tag\"># train the generator and critic\r\ndef train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=64, n_critic=5):\r\n\t# calculate the number of batches per training epoch\r\n\tbat_per_epo = int(dataset.shape[0] \/ n_batch)\r\n\t# calculate the number of training iterations\r\n\tn_steps = bat_per_epo * n_epochs\r\n\t# calculate the size of half a batch of samples\r\n\thalf_batch = int(n_batch \/ 2)\r\n\t# lists for keeping track of loss\r\n\tc1_hist, c2_hist, g_hist = list(), list(), list()\r\n\t# manually enumerate epochs\r\n\tfor i in range(n_steps):\r\n\t\t# update the critic more than the generator\r\n\t\tc1_tmp, c2_tmp = list(), list()\r\n\t\tfor _ in range(n_critic):\r\n\t\t\t# get randomly selected 'real' samples\r\n\t\t\tX_real, y_real = generate_real_samples(dataset, half_batch)\r\n\t\t\t# update critic model weights\r\n\t\t\tc_loss1 = c_model.train_on_batch(X_real, y_real)\r\n\t\t\tc1_tmp.append(c_loss1)\r\n\t\t\t# generate 'fake' examples\r\n\t\t\tX_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)\r\n\t\t\t# update critic model weights\r\n\t\t\tc_loss2 = c_model.train_on_batch(X_fake, y_fake)\r\n\t\t\tc2_tmp.append(c_loss2)\r\n\t\t# store critic loss\r\n\t\tc1_hist.append(mean(c1_tmp))\r\n\t\tc2_hist.append(mean(c2_tmp))\r\n\t\t# prepare points in latent space as input for the generator\r\n\t\tX_gan = generate_latent_points(latent_dim, n_batch)\r\n\t\t# create inverted labels for the fake samples\r\n\t\ty_gan = -ones((n_batch, 1))\r\n\t\t# update the generator via the critic's error\r\n\t\tg_loss = gan_model.train_on_batch(X_gan, y_gan)\r\n\t\tg_hist.append(g_loss)\r\n\t\t# summarize loss on this batch\r\n\t\tprint('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))\r\n\t\t# evaluate the model performance every 
'epoch'\r\n\t\tif (i+1) % bat_per_epo == 0:\r\n\t\t\tsummarize_performance(i, g_model, latent_dim)\r\n\t# line plots of loss\r\n\tplot_history(c1_hist, c2_hist, g_hist)<\/pre>\n<p>Now that all of the functions have been defined, we can create the models, load the dataset, and begin the training process.<\/p>\n<pre class=\"crayon-plain-tag\"># size of the latent space\r\nlatent_dim = 50\r\n# create the critic\r\ncritic = define_critic()\r\n# create the generator\r\ngenerator = define_generator(latent_dim)\r\n# create the gan\r\ngan_model = define_gan(generator, critic)\r\n# load image data\r\ndataset = load_real_samples()\r\nprint(dataset.shape)\r\n# train model\r\ntrain(generator, critic, gan_model, dataset, latent_dim)<\/pre>\n<p>Tying all of this together, the complete example is listed below.<\/p>\n<pre class=\"crayon-plain-tag\"># example of a wgan for generating handwritten digits\r\nfrom numpy import expand_dims\r\nfrom numpy import mean\r\nfrom numpy import ones\r\nfrom numpy.random import randn\r\nfrom numpy.random import randint\r\nfrom keras.datasets.mnist import load_data\r\nfrom keras import backend\r\nfrom keras.optimizers import RMSprop\r\nfrom keras.models import Sequential\r\nfrom keras.layers import Dense\r\nfrom keras.layers import Reshape\r\nfrom keras.layers import Flatten\r\nfrom keras.layers import Conv2D\r\nfrom keras.layers import Conv2DTranspose\r\nfrom keras.layers import LeakyReLU\r\nfrom keras.layers import BatchNormalization\r\nfrom keras.initializers import RandomNormal\r\nfrom keras.constraints import Constraint\r\nfrom matplotlib import pyplot\r\n\r\n# clip model weights to a given hypercube\r\nclass ClipConstraint(Constraint):\r\n\t# set clip value when initialized\r\n\tdef __init__(self, clip_value):\r\n\t\tself.clip_value = clip_value\r\n\r\n\t# clip model weights to hypercube\r\n\tdef __call__(self, weights):\r\n\t\treturn backend.clip(weights, -self.clip_value, self.clip_value)\r\n\r\n\t# get the config\r\n\tdef 
get_config(self):\r\n\t\treturn {'clip_value': self.clip_value}\r\n\r\n# calculate wasserstein loss\r\ndef wasserstein_loss(y_true, y_pred):\r\n\treturn backend.mean(y_true * y_pred)\r\n\r\n# define the standalone critic model\r\ndef define_critic(in_shape=(28,28,1)):\r\n\t# weight initialization\r\n\tinit = RandomNormal(stddev=0.02)\r\n\t# weight constraint\r\n\tconst = ClipConstraint(0.01)\r\n\t# define model\r\n\tmodel = Sequential()\r\n\t# downsample to 14x14\r\n\tmodel.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# downsample to 7x7\r\n\tmodel.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# scoring, linear activation\r\n\tmodel.add(Flatten())\r\n\tmodel.add(Dense(1))\r\n\t# compile model\r\n\topt = RMSprop(lr=0.00005)\r\n\tmodel.compile(loss=wasserstein_loss, optimizer=opt)\r\n\treturn model\r\n\r\n# define the standalone generator model\r\ndef define_generator(latent_dim):\r\n\t# weight initialization\r\n\tinit = RandomNormal(stddev=0.02)\r\n\t# define model\r\n\tmodel = Sequential()\r\n\t# foundation for 7x7 image\r\n\tn_nodes = 128 * 7 * 7\r\n\tmodel.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\tmodel.add(Reshape((7, 7, 128)))\r\n\t# upsample to 14x14\r\n\tmodel.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# upsample to 28x28\r\n\tmodel.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))\r\n\tmodel.add(BatchNormalization())\r\n\tmodel.add(LeakyReLU(alpha=0.2))\r\n\t# output 28x28x1\r\n\tmodel.add(Conv2D(1, (7,7), activation='tanh', padding='same', 
kernel_initializer=init))\r\n\treturn model\r\n\r\n# define the combined generator and critic model, for updating the generator\r\ndef define_gan(generator, critic):\r\n\t# make weights in the critic not trainable\r\n\tcritic.trainable = False\r\n\t# connect them\r\n\tmodel = Sequential()\r\n\t# add generator\r\n\tmodel.add(generator)\r\n\t# add the critic\r\n\tmodel.add(critic)\r\n\t# compile model\r\n\topt = RMSprop(lr=0.00005)\r\n\tmodel.compile(loss=wasserstein_loss, optimizer=opt)\r\n\treturn model\r\n\r\n# load images\r\ndef load_real_samples():\r\n\t# load dataset\r\n\t(trainX, trainy), (_, _) = load_data()\r\n\t# select all of the examples for a given class\r\n\tselected_ix = trainy == 7\r\n\tX = trainX[selected_ix]\r\n\t# expand to 3d, e.g. add channels\r\n\tX = expand_dims(X, axis=-1)\r\n\t# convert from ints to floats\r\n\tX = X.astype('float32')\r\n\t# scale from [0,255] to [-1,1]\r\n\tX = (X - 127.5) \/ 127.5\r\n\treturn X\r\n\r\n# select real samples\r\ndef generate_real_samples(dataset, n_samples):\r\n\t# choose random instances\r\n\tix = randint(0, dataset.shape[0], n_samples)\r\n\t# select images\r\n\tX = dataset[ix]\r\n\t# generate class labels, -1 for 'real'\r\n\ty = -ones((n_samples, 1))\r\n\treturn X, y\r\n\r\n# generate points in latent space as input for the generator\r\ndef generate_latent_points(latent_dim, n_samples):\r\n\t# generate points in the latent space\r\n\tx_input = randn(latent_dim * n_samples)\r\n\t# reshape into a batch of inputs for the network\r\n\tx_input = x_input.reshape(n_samples, latent_dim)\r\n\treturn x_input\r\n\r\n# use the generator to generate n fake examples, with class labels\r\ndef generate_fake_samples(generator, latent_dim, n_samples):\r\n\t# generate points in latent space\r\n\tx_input = generate_latent_points(latent_dim, n_samples)\r\n\t# predict outputs\r\n\tX = generator.predict(x_input)\r\n\t# create class labels with 1.0 for 'fake'\r\n\ty = ones((n_samples, 1))\r\n\treturn X, y\r\n\r\n# generate samples 
and save as a plot and save the model\r\ndef summarize_performance(step, g_model, latent_dim, n_samples=100):\r\n\t# prepare fake examples\r\n\tX, _ = generate_fake_samples(g_model, latent_dim, n_samples)\r\n\t# scale from [-1,1] to [0,1]\r\n\tX = (X + 1) \/ 2.0\r\n\t# plot images\r\n\tfor i in range(10 * 10):\r\n\t\t# define subplot\r\n\t\tpyplot.subplot(10, 10, 1 + i)\r\n\t\t# turn off axis\r\n\t\tpyplot.axis('off')\r\n\t\t# plot raw pixel data\r\n\t\tpyplot.imshow(X[i, :, :, 0], cmap='gray_r')\r\n\t# save plot to file\r\n\tfilename1 = 'generated_plot_%04d.png' % (step+1)\r\n\tpyplot.savefig(filename1)\r\n\tpyplot.close()\r\n\t# save the generator model\r\n\tfilename2 = 'model_%04d.h5' % (step+1)\r\n\tg_model.save(filename2)\r\n\tprint('>Saved: %s and %s' % (filename1, filename2))\r\n\r\n# create a line plot of loss for the gan and save to file\r\ndef plot_history(d1_hist, d2_hist, g_hist):\r\n\t# plot history\r\n\tpyplot.plot(d1_hist, label='crit_real')\r\n\tpyplot.plot(d2_hist, label='crit_fake')\r\n\tpyplot.plot(g_hist, label='gen')\r\n\tpyplot.legend()\r\n\tpyplot.savefig('plot_line_plot_loss.png')\r\n\tpyplot.close()\r\n\r\n# train the generator and critic\r\ndef train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=64, n_critic=5):\r\n\t# calculate the number of batches per training epoch\r\n\tbat_per_epo = int(dataset.shape[0] \/ n_batch)\r\n\t# calculate the number of training iterations\r\n\tn_steps = bat_per_epo * n_epochs\r\n\t# calculate the size of half a batch of samples\r\n\thalf_batch = int(n_batch \/ 2)\r\n\t# lists for keeping track of loss\r\n\tc1_hist, c2_hist, g_hist = list(), list(), list()\r\n\t# manually enumerate epochs\r\n\tfor i in range(n_steps):\r\n\t\t# update the critic more than the generator\r\n\t\tc1_tmp, c2_tmp = list(), list()\r\n\t\tfor _ in range(n_critic):\r\n\t\t\t# get randomly selected 'real' samples\r\n\t\t\tX_real, y_real = generate_real_samples(dataset, half_batch)\r\n\t\t\t# update critic model 
weights\r\n\t\t\tc_loss1 = c_model.train_on_batch(X_real, y_real)\r\n\t\t\tc1_tmp.append(c_loss1)\r\n\t\t\t# generate 'fake' examples\r\n\t\t\tX_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)\r\n\t\t\t# update critic model weights\r\n\t\t\tc_loss2 = c_model.train_on_batch(X_fake, y_fake)\r\n\t\t\tc2_tmp.append(c_loss2)\r\n\t\t# store critic loss\r\n\t\tc1_hist.append(mean(c1_tmp))\r\n\t\tc2_hist.append(mean(c2_tmp))\r\n\t\t# prepare points in latent space as input for the generator\r\n\t\tX_gan = generate_latent_points(latent_dim, n_batch)\r\n\t\t# create inverted labels for the fake samples\r\n\t\ty_gan = -ones((n_batch, 1))\r\n\t\t# update the generator via the critic's error\r\n\t\tg_loss = gan_model.train_on_batch(X_gan, y_gan)\r\n\t\tg_hist.append(g_loss)\r\n\t\t# summarize loss on this batch\r\n\t\tprint('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))\r\n\t\t# evaluate the model performance every 'epoch'\r\n\t\tif (i+1) % bat_per_epo == 0:\r\n\t\t\tsummarize_performance(i, g_model, latent_dim)\r\n\t# line plots of loss\r\n\tplot_history(c1_hist, c2_hist, g_hist)\r\n\r\n# size of the latent space\r\nlatent_dim = 50\r\n# create the critic\r\ncritic = define_critic()\r\n# create the generator\r\ngenerator = define_generator(latent_dim)\r\n# create the gan\r\ngan_model = define_gan(generator, critic)\r\n# load image data\r\ndataset = load_real_samples()\r\nprint(dataset.shape)\r\n# train model\r\ntrain(generator, critic, gan_model, dataset, latent_dim)<\/pre>\n<p>Running the example is quick, taking approximately 10 minutes on modern hardware without a GPU.<\/p>\n<p>Your specific results will vary given the stochastic nature of the learning algorithm. Nevertheless, the general structure of training should be very similar.<\/p>\n<p>First, the loss of the critic and generator models is reported to the console each iteration of the training loop. 
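<\/p>\n<p>The signs of these reported values follow directly from the Wasserstein loss, which simply multiplies the critic's raw scores by the target labels (-1 for real, +1 for fake) and takes the mean. The sketch below mirrors the tutorial's <em>wasserstein_loss()<\/em> using NumPy in place of the Keras backend; the score values are illustrative, not taken from an actual run.<\/p>

```python
# demonstrate how the target labels flip the sign of the reported wasserstein loss
from numpy import array, mean

# wasserstein loss as in the tutorial, computed with numpy instead of the keras backend
def wasserstein_loss(y_true, y_pred):
	return mean(y_true * y_pred)

# hypothetical raw critic scores (chosen only to resemble the run reported below)
real_scores = array([-5.0, -6.0])   # critic scores for real images
fake_scores = array([-14.0, -15.0]) # critic scores for generated images

# c1: real samples are labeled -1, so the reported loss has its sign flipped
c1 = wasserstein_loss(-1.0, real_scores)   # 5.5 (raw scores average -5.5)
# c2: fake samples are labeled +1, so the sign is unchanged
c2 = wasserstein_loss(1.0, fake_scores)    # -14.5
# g: the generator is updated with the label -1 on fake samples, so its
# reported loss is also sign-flipped
g = wasserstein_loss(-1.0, fake_scores)    # 14.5
print(c1, c2, g)
```

<p>With these labels, a critic that scores fake images at around -20 yields a reported generator loss of around +20.<\/p>\n<p>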
Specifically, c1 is the loss of the critic on real examples, c2 is the loss of the critic on generated samples, and g is the loss of the generator trained via the critic.<\/p>\n<p>The c1 scores are inverted as part of the loss function; this means that if they are reported as negative, they are really positive, and if they are reported as positive, they are really negative. The sign of the c2 scores is unchanged.<\/p>\n<p>Recall that the Wasserstein loss encourages the critic to push the scores for real and fake samples further apart during training. We can see this towards the end of the run, such as the final epoch, where the <em>c1<\/em> loss for real examples is 5.338 (really -5.338) and the <em>c2<\/em> loss for fake examples is -14.260; this separation of about 10 units is consistent for at least the prior few iterations.<\/p>\n<p>We can also see that, in this case, the model scores the loss of the generator at around 20. Again, recall that we update the generator via the critic model and treat the generated examples as real with a target of -1; therefore, the score can be interpreted as a value around -20, close to the loss for fake samples.<\/p>\n<pre class=\"crayon-plain-tag\">...\r\n>961, c1=5.110, c2=-15.388 g=19.579\r\n>962, c1=6.116, c2=-15.222 g=20.054\r\n>963, c1=4.982, c2=-15.192 g=21.048\r\n>964, c1=4.238, c2=-14.689 g=23.911\r\n>965, c1=5.585, c2=-14.126 g=19.578\r\n>966, c1=4.807, c2=-14.755 g=20.034\r\n>967, c1=6.307, c2=-16.538 g=19.572\r\n>968, c1=4.298, c2=-14.178 g=17.956\r\n>969, c1=4.283, c2=-13.398 g=17.326\r\n>970, c1=5.338, c2=-14.260 g=19.927<\/pre>\n<p>Line plots of loss are created and saved at the end of the run.<\/p>\n<p>The plot shows the loss for the critic on real samples (blue), the loss for the critic on fake samples (orange), and the loss for the critic when updating the generator with fake samples (green).<\/p>\n<p>There is one important factor when reviewing <a 
href=\"https:\/\/machinelearningmastery.com\/learning-curves-for-diagnosing-machine-learning-model-performance\/\">learning curves<\/a> for the WGAN, and that is the trend.<\/p>\n<p>The benefit of the WGAN is that the loss correlates with generated image quality: for a stable training process, lower loss means better quality images.<\/p>\n<p>In this case, lower loss specifically refers to lower Wasserstein loss for generated images as reported by the critic (orange line). The sign of this loss is not inverted by the target label (e.g. the target label is +1.0); therefore, a well-performing WGAN should show this line trending down as the quality of the generated images increases.<\/p>\n<div id=\"attachment_8223\" style=\"width: 650px\" class=\"wp-caption aligncenter\"><img loading=\"lazy\" decoding=\"async\" aria-describedby=\"caption-attachment-8223\" class=\"size-full wp-image-8223\" src=\"https:\/\/machinelearningmastery.com\/wp-content\/uploads\/2019\/05\/Line-Plots-of-Loss-and-Accuracy-for-a-Wasserstein-Generative-Adversarial-Network.png\" alt=\"Line Plots of Loss and Accuracy for a Wasserstein Generative Adversarial Network\" width=\"640\" height=\"480\" srcset=\"http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Line-Plots-of-Loss-and-Accuracy-for-a-Wasserstein-Generative-Adversarial-Network.png 640w, http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Line-Plots-of-Loss-and-Accuracy-for-a-Wasserstein-Generative-Adversarial-Network-300x225.png 300w\" sizes=\"(max-width: 640px) 100vw, 640px\"><\/p>\n<p id=\"caption-attachment-8223\" class=\"wp-caption-text\">Line Plots of Loss and Accuracy for a Wasserstein Generative Adversarial Network<\/p>\n<\/div>\n<p>In this case, more training seems to result in better quality generated images, with a major hurdle occurring around iteration 200 to 300, after which image quality remains quite good for the model.<\/p>\n<p>Before and around this 
hurdle, image quality is poor; for example:<\/p>\n<div id=\"attachment_8224\" style=\"width: 650px\" class=\"wp-caption aligncenter\"><img loading=\"lazy\" decoding=\"async\" aria-describedby=\"caption-attachment-8224\" class=\"size-full wp-image-8224\" src=\"https:\/\/machinelearningmastery.com\/wp-content\/uploads\/2019\/05\/Sample-of-100-Generated-Images-of-a-Handwritten-Number-7-at-Epoch-97-from-a-Wasserstein-GAN.png\" alt=\"Sample of 100 Generated Images of a Handwritten Number 7 at Epoch 97 from a Wasserstein GAN.\" width=\"640\" height=\"480\" srcset=\"http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Sample-of-100-Generated-Images-of-a-Handwritten-Number-7-at-Epoch-97-from-a-Wasserstein-GAN.png 640w, http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Sample-of-100-Generated-Images-of-a-Handwritten-Number-7-at-Epoch-97-from-a-Wasserstein-GAN-300x225.png 300w\" sizes=\"(max-width: 640px) 100vw, 640px\"><\/p>\n<p id=\"caption-attachment-8224\" class=\"wp-caption-text\">Sample of 100 Generated Images of a Handwritten Number 7 at Epoch 97 from a Wasserstein GAN.<\/p>\n<\/div>\n<p>After this hurdle, the WGAN continues to generate plausible handwritten digits.<\/p>\n<div id=\"attachment_8225\" style=\"width: 650px\" class=\"wp-caption aligncenter\"><img loading=\"lazy\" decoding=\"async\" aria-describedby=\"caption-attachment-8225\" class=\"size-full wp-image-8225\" src=\"https:\/\/machinelearningmastery.com\/wp-content\/uploads\/2019\/05\/Sample-of-100-Generated-Images-of-a-Handwritten-Number-7-at-Epoch-970-from-a-Wasserstein-GAN.png\" alt=\"Sample of 100 Generated Images of a Handwritten Number 7 at Epoch 970 from a Wasserstein GAN.\" width=\"640\" height=\"480\" srcset=\"http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Sample-of-100-Generated-Images-of-a-Handwritten-Number-7-at-Epoch-970-from-a-Wasserstein-GAN.png 640w, 
http:\/\/3qeqpr26caki16dnhd19sv6by6v.wpengine.netdna-cdn.com\/wp-content\/uploads\/2019\/05\/Sample-of-100-Generated-Images-of-a-Handwritten-Number-7-at-Epoch-970-from-a-Wasserstein-GAN-300x225.png 300w\" sizes=\"(max-width: 640px) 100vw, 640px\"><\/p>\n<p id=\"caption-attachment-8225\" class=\"wp-caption-text\">Sample of 100 Generated Images of a Handwritten Number 7 at Epoch 970 from a Wasserstein GAN.<\/p>\n<\/div>\n<h2>Further Reading<\/h2>\n<p>This section provides more resources on the topic if you are looking to go deeper.<\/p>\n<h3>Papers<\/h3>\n<ul>\n<li><a href=\"https:\/\/arxiv.org\/abs\/1701.07875\">Wasserstein GAN<\/a>, 2017.<\/li>\n<li><a href=\"https:\/\/arxiv.org\/abs\/1704.00028\">Improved Training of Wasserstein GANs<\/a>, 2017.<\/li>\n<\/ul>\n<h3>API<\/h3>\n<ul>\n<li><a href=\"https:\/\/keras.io\/datasets\/\">Keras Datasets API<\/a>.<\/li>\n<li><a href=\"https:\/\/keras.io\/models\/sequential\/\">Keras Sequential Model API<\/a><\/li>\n<li><a href=\"https:\/\/keras.io\/layers\/convolutional\/\">Keras Convolutional Layers API<\/a><\/li>\n<li><a href=\"https:\/\/keras.io\/getting-started\/faq\/#how-can-i-freeze-keras-layers\">How can I \u201cfreeze\u201d Keras layers?<\/a><\/li>\n<li><a href=\"https:\/\/matplotlib.org\/api\/\">MatplotLib API<\/a><\/li>\n<li><a href=\"https:\/\/docs.scipy.org\/doc\/numpy\/reference\/routines.random.html\">NumPy Random sampling (numpy.random) API<\/a><\/li>\n<li><a href=\"https:\/\/docs.scipy.org\/doc\/numpy\/reference\/routines.array-manipulation.html\">NumPy Array manipulation routines<\/a><\/li>\n<\/ul>\n<h3>Articles<\/h3>\n<ul>\n<li><a href=\"https:\/\/github.com\/martinarjovsky\/WassersteinGAN\">WassersteinGAN, GitHub<\/a>.<\/li>\n<li><a href=\"https:\/\/github.com\/kpandey008\/wasserstein-gans\">Wasserstein Generative Adversarial Networks (WGANS) Project, GitHub<\/a>.<\/li>\n<li><a href=\"https:\/\/github.com\/eriklindernoren\/Keras-GAN\">Keras-GAN: Keras implementations of Generative Adversarial Networks, 
GitHub<\/a>.<\/li>\n<li><a href=\"https:\/\/github.com\/keras-team\/keras-contrib\/blob\/master\/examples\/improved_wgan.py\">Improved WGAN, keras-contrib Project, GitHub.<\/a><\/li>\n<\/ul>\n<h2>Summary<\/h2>\n<p>In this tutorial, you discovered how to implement the Wasserstein generative adversarial network from scratch.<\/p>\n<p>Specifically, you learned:<\/p>\n<ul>\n<li>The differences between the standard deep convolutional GAN and the new Wasserstein GAN.<\/li>\n<li>How to implement the specific details of the Wasserstein GAN from scratch.<\/li>\n<li>How to develop a WGAN for image generation and interpret the dynamic behavior of the model.<\/li>\n<\/ul>\n<p>Do you have any questions?<br \/>\nAsk your questions in the comments below and I will do my best to answer.<\/p>\n<p>The post <a rel=\"nofollow\" href=\"https:\/\/machinelearningmastery.com\/how-to-code-a-wasserstein-generative-adversarial-network-wgan-from-scratch\/\">How to Develop a Wasserstein Generative Adversarial Network (WGAN) From Scratch<\/a> appeared first on <a rel=\"nofollow\" href=\"https:\/\/machinelearningmastery.com\/\">Machine Learning Mastery<\/a>.<\/p>\n<\/div>\n<p><a href=\"https:\/\/machinelearningmastery.com\/how-to-code-a-wasserstein-generative-adversarial-network-wgan-from-scratch\/\">Go to Source<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Author: Jason Brownlee The Wasserstein Generative Adversarial Network, or Wasserstein GAN, is an extension to the generative adversarial network that both improves the stability when [&hellip;] <span class=\"read-more-link\"><a class=\"read-more\" href=\"https:\/\/www.aiproblog.com\/index.php\/2019\/07\/16\/how-to-develop-a-wasserstein-generative-adversarial-network-wgan-from-scratch\/\">Read 
More<\/a><\/span><\/p>\n","protected":false},"author":1,"featured_media":2371,"comment_status":"open","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"_bbp_topic_count":0,"_bbp_reply_count":0,"_bbp_total_topic_count":0,"_bbp_total_reply_count":0,"_bbp_voice_count":0,"_bbp_anonymous_reply_count":0,"_bbp_topic_count_hidden":0,"_bbp_reply_count_hidden":0,"_bbp_forum_subforum_count":0,"footnotes":""},"categories":[24],"tags":[],"_links":{"self":[{"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/posts\/2370"}],"collection":[{"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/comments?post=2370"}],"version-history":[{"count":0,"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/posts\/2370\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/media\/2371"}],"wp:attachment":[{"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/media?parent=2370"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/categories?post=2370"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aiproblog.com\/index.php\/wp-json\/wp\/v2\/tags?post=2370"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}