An introduction to entropy, cross entropy and KL divergence in machine learning

If you’ve been involved with neural networks and have been using them for classification, you almost certainly will have used a cross entropy loss function. But have you really understood what cross entropy means? Do you know what entropy means, in the context of machine learning? If not, then this post is for you. In this introduction, I’ll carefully unpack the concepts and mathematics behind entropy, cross entropy and a related concept, KL divergence, to give you a better foundational understanding of these important ideas. For starters, let’s look at the concept of entropy.

Entropy

The term entropy originated in statistical thermodynamics, a sub-domain of physics. In machine learning, however, we are more interested in entropy as defined in information theory, also known as Shannon entropy. This formulation of entropy is closely tied to the allied idea of information: entropy is the average rate of information produced by a stochastic process. As such, we first need to unpack what the term “information” means in an information theory context. If you’re feeling a bit lost at this stage, don’t worry, things will become much clearer soon.


Information content

Information I in information theory is generally measured in bits, and can loosely, yet instructively, be defined as the amount of “surprise” arising from a given event. To take a simple example – imagine we have an extremely unfair coin which, when flipped, has a 99% chance of landing heads and only a 1% chance of landing tails. We can represent this probability distribution using set notation as {0.99, 0.01}.

If we toss the coin once, and it lands heads, we aren’t very surprised and hence the information “transmitted” by such an event is low. Alternatively, if it lands tails, we would be very surprised (given what we know about the coin) and therefore the information of such an event would be high. This can be represented mathematically by the following formula:

$$I(E) = -log[Pr(E)] = -log(P)$$

This equation gives the information content of a stochastic event E as the negative log of the probability of the event, which can be expressed more simply as $-log(P)$. One thing to note – if we are dealing with information expressed in bits (i.e. each bit is either a 0 or 1), the logarithm has a base of 2, so $I(E) = -log_{2}(P)$. An alternative unit often used in machine learning is the nat, which applies when the natural logarithm is used.

For the coin toss with our unfair coin, the information entailed by heads would be $-log_{2}(0.99) = 0.0145$ bits, which is quite low. Alternatively, for tails the information is equal to $-log_{2}(0.01) = 6.64$ bits. So this lines up nicely with the interpretation of information as “surprise” discussed above.
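As a quick sanity check, here is a minimal Python sketch (using NumPy) of these information calculations:

```python
import numpy as np

def information(p, base=2):
    # Information content of an event with probability p (bits for base 2)
    return -np.log(p) / np.log(base)

print(information(0.99))  # ~0.0145 bits - heads, not very surprising
print(information(0.01))  # ~6.64 bits - tails, very surprising
```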

Now, recall that entropy is defined as the average rate of information produced from a stochastic process. What is the average or expected rate of information produced from the process of flipping our very unfair coin?

Entropy and information

How do we calculate the expected or average value of something? Recall that the expected value of a variable X is given by:

$$E[X] = \sum_{i=1}^{n} x_{i}p_{i}$$

Where $x_{i}$ is the i-th possible value of $x$ and $p_{i}$ is the probability of that $x$ occurring. Likewise, we can define the entropy as the expected value of information, so that it looks like this:

$$H(X) = E[I(X)] = E[-log(P(X))] = -\sum_{i=1}^{n}P(x_{i})logP(x_{i})$$

We can now apply this to the unfair coin question like so:

$$H(X) = -(0.99log_{2}(0.99) + 0.01log_{2}(0.01)) = 0.08 \text{ bits}$$

It can therefore be said that the unfair coin is a stochastic information generator with an average information delivery rate of 0.08 bits. This is quite small, because the result is dominated by the high probability of the heads outcome. By contrast, a fair coin has an entropy of 1 bit. In any case, this should give you a good idea of what entropy is, what it measures and how to calculate it.
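The same calculation as a short Python sketch, again using base-2 logarithms:

```python
import numpy as np

def entropy(probs, base=2):
    # Expected information content of a discrete probability distribution
    probs = np.asarray(probs)
    return -np.sum(probs * np.log(probs)) / np.log(base)

print(entropy([0.99, 0.01]))  # ~0.08 bits - the unfair coin
print(entropy([0.5, 0.5]))    # 1.0 bit  - a fair coin
```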

So far so good. However, what does this have to do with machine learning?

Entropy and machine learning

You might be wondering at this point where entropy is used in machine learning. First of all, it is a central component of cross entropy, which you are probably already familiar with – more on that later. However, entropy is also used in its own right within machine learning. One notable and instructive instance is its use in policy gradient optimization in reinforcement learning. In such a case, a neural network is trained to control an agent, and its output layer is a softmax layer representing a probability distribution over the actions the agent can take.

The output, for an environment with an action size of 4, may look something like this for a given game state:

{0.9, 0.05, 0.025, 0.025}

In the case above, the agent will most likely choose the first action (i.e. p = 0.9). So far so good – but how does entropy come into this? One of the key problems which needs to be addressed in reinforcement learning is making sure the agent doesn’t converge on one set of actions or strategies too quickly; in other words, we want to encourage exploration. In the policy gradient version of reinforcement learning, exploration can be encouraged by adding the negative of the entropy of the output layer to the loss function. Thus, as the loss is minimized, any “narrowing down” of the probabilities of the agent’s actions must be beneficial enough to counteract the accompanying increase in the negative entropy term.

For instance, the entropy of the example output above, given its predilection for choosing the first action (i.e. p = 0.9), is quite low – 0.61 bits. Consider, however, an alternative output of the softmax layer:

{0.3, 0.2, 0.25, 0.25}

In this case, the entropy is larger, at 1.98 bits, given that there is more uncertainty about which action the agent should choose. If the negative of the entropy is included in the loss function, a higher entropy will reduce the loss value more than a lower entropy will, and hence there will be a tendency not to converge too quickly on a definitive (i.e. low entropy) set of actions.
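To make this concrete, below is a minimal, hypothetical sketch of how such an entropy bonus might be folded into a policy gradient loss. It is not any particular library’s API – `action_probs`, `log_prob_taken`, `advantage` and `entropy_coeff` are all illustrative names:

```python
import numpy as np

def policy_loss(action_probs, log_prob_taken, advantage, entropy_coeff=0.01):
    # Standard policy gradient term: encourage actions with positive advantage
    pg_loss = -log_prob_taken * advantage
    # Entropy of the softmax output (in nats, i.e. natural log)
    entropy = -np.sum(action_probs * np.log(action_probs))
    # Subtracting the entropy (adding its negative) rewards more uniform
    # action distributions, which encourages exploration
    return pg_loss - entropy_coeff * entropy

# The confident policy receives a smaller entropy "discount" on its loss
print(policy_loss(np.array([0.9, 0.05, 0.025, 0.025]), np.log(0.9), 1.0))
print(policy_loss(np.array([0.3, 0.2, 0.25, 0.25]), np.log(0.3), 1.0))
```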

Entropy is also used in certain Bayesian methods in machine learning, but these won’t be discussed here. It is now time to consider the commonly used cross entropy loss function.

Cross entropy and KL divergence

Cross entropy is, at its core, a way of measuring the “distance” between two probability distributions P and Q. As we saw above, entropy on its own is a property of a single probability distribution. If we are trying to model a true probability distribution P using, say, a neural network that produces an approximate probability distribution Q, then we need some sort of distance or difference measure which can be minimized.

The difference measure in cross entropy arises from something called the Kullback–Leibler (KL) divergence. This can be seen in the definition of the cross entropy function:

$$H(p, q) = H(p) + D_{KL}(p \parallel q)$$

The first term, the entropy of the true probability distribution p, is fixed during optimization – it reduces to an additive constant. Only the parameters of the second, approximating distribution q can be varied during optimization, and hence the core of the cross entropy measure of distance is the KL divergence function.

KL divergence

The KL divergence between two distributions has many different interpretations from an information theoretic perspective. In simplified terms, it too is an expression of “surprise” – under the assumption that P and Q are close, it is surprising if it turns out that they are not, and in those cases the KL divergence will be high. If the two distributions are close together, the KL divergence will be low.

Another interpretation of KL divergence, from a Bayesian perspective, is intuitive – this interpretation says KL divergence is the information gained when we move from a prior distribution Q to a posterior distribution P. The expression for KL divergence can also be derived by using a likelihood ratio approach.

The likelihood ratio can be expressed as:

$$LR = \frac{p(x)}{q(x)}$$

This can be interpreted as follows: if a value x is sampled from some unknown distribution, the likelihood ratio expresses how much more likely it is that the sample came from distribution p than from distribution q. If the sample is more likely to have come from p, then LR > 1; if it is more likely to have come from q, then LR < 1.

So far so good. Let’s say we have lots of independent samples and we want to compute the likelihood ratio taking into account all this evidence – it then becomes:

$$LR = \prod_{i=0}^{n}\frac{p(x_{i})}{q(x_{i})}$$

If we take the logarithm of the ratio, the product in the above definition turns into a summation:

$$log(LR) = \sum_{i=0}^{n}log\left(\frac{p(x_{i})}{q(x_{i})}\right)$$

So now we have the log likelihood ratio expressed as a summation. Let’s say we want to answer the question of how much, on average, each sample gives evidence for $p(x)$ over $q(x)$. To do this, we take the expected value of the log likelihood ratio under the distribution P and arrive at:

$$D_{KL}(P\parallel Q) = \sum_{i=0}^{n}p(x_{i})log\left(\frac{p(x_{i})}{q(x_{i})}\right)$$

The expression above is the definition of the KL divergence. It is essentially the expected value, under P, of the log likelihood ratio – where the likelihood ratio expresses how much more likely the sampled data is to have come from distribution P than from distribution Q. Another way of expressing the above definition is as follows (using log rules):

$$D_{KL}(P\parallel Q) = \sum_{i=0}^{n}p(x_{i})log (p(x_{i})) - \sum_{i=0}^{n}p(x_{i})log (q(x_{i}))$$

The first term in the above equation is the negative of the entropy of the distribution P – recall that the entropy is the expected value of the information content of P. The second term ($\sum_{i=0}^{n}p(x_{i})log (q(x_{i}))$), which is subtracted, is likewise the negative of the information content of Q, but weighted by the distribution P. This yields the following interpretation of the KL divergence – if P is the “true” distribution, then the KL divergence is the amount of information “lost” when expressing it via Q.
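As a quick numerical illustration, here is a minimal Python sketch of the KL divergence for two small discrete distributions (the numbers are purely illustrative):

```python
import numpy as np

def kl_divergence(p, q):
    # D_KL(P || Q) = sum_i p_i * log(p_i / q_i), here in nats (natural log)
    p, q = np.asarray(p), np.asarray(q)
    return np.sum(p * np.log(p / q))

p = [0.8, 0.1, 0.1]   # "true" distribution
q = [0.6, 0.3, 0.1]   # approximating distribution

print(kl_divergence(p, q))  # ~0.12 nats
print(kl_divergence(q, p))  # a different value - KL divergence is not symmetric
```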

However you wish to interpret the KL divergence, it is clearly a difference measure between the probability distributions P and Q. It is only a “quasi” distance measure however, as $D_{KL}(P \parallel Q) \neq D_{KL}(Q \parallel P)$.

Now we need to show how the KL divergence generates the cross-entropy function.

Cross entropy

As explained previously, the cross entropy is a combination of the entropy of the “true” distribution P and the KL divergence between P and Q:

$$H(p, q) = H(p) + D_{KL}(p \parallel q)$$

Using the definition of the entropy and KL divergence, and log rules, we can arrive at the following cross entropy definition:

$$H(p, q) = -\sum_{i=0}^{n}p(x_{i})log (q(x_{i}))$$
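Continuing the illustrative distributions from the previous sketch, here is a quick numerical check of this relationship (in nats):

```python
import numpy as np

p = np.array([0.8, 0.1, 0.1])  # "true" distribution (illustrative)
q = np.array([0.6, 0.3, 0.1])  # approximating distribution

cross_entropy = -np.sum(p * np.log(q))   # H(P, Q)
entropy_p = -np.sum(p * np.log(p))       # H(P)
kl_pq = np.sum(p * np.log(p / q))        # D_KL(P || Q)

print(cross_entropy)        # ~0.76 nats
print(entropy_p + kl_pq)    # the same value: H(P) + D_KL(P || Q)
```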

What does this function look like in practice for a classification task in a neural network? In such a task, we are usually dealing with a true distribution P that is a one-hot encoded vector. So, for instance, in the MNIST hand-written digit classification task, if the image represents a hand-written digit “2”, P will look like:

{0, 0, 1, 0, 0, 0, 0, 0, 0, 0}

The output layer of our neural network in such a task will be a softmax layer, where all outputs have been normalized so they sum to one, representing a probability distribution over the ten classes. The output layer, Q, for this image could be:

{0.01, 0.02, 0.75, 0.05, 0.02, 0.1, 0.001, 0.02, 0.009, 0.02}

To get the predicted class, one would run np.argmax over the output and, in this example, we would get the correct prediction. However, observe how the cross entropy loss function works in this instance. For all indices apart from i=2, $p(x_{i}) = 0$, so those terms in the summation fall to zero. The only index which contributes is i=2. As such, for one-hot encoded vectors, the cross entropy collapses to:

$$H(p,q) = -log(q(x_{i}))$$

In this example, the cross entropy loss would be $-log(0.75) = 0.287$ (using nats as the information unit). The closer the Q value for the i=2 index gets to 1, the lower the loss becomes. This is because the KL divergence between P and Q decreases as Q approaches P.
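Here is a minimal sketch of this collapse, using the illustrative softmax output above:

```python
import numpy as np

p = np.zeros(10)
p[2] = 1.0  # one-hot target for the digit "2"
q = np.array([0.01, 0.02, 0.75, 0.05, 0.02, 0.1, 0.001, 0.02, 0.009, 0.02])

full_cross_entropy = -np.sum(p * np.log(q))  # full summation, in nats
collapsed = -np.log(q[2])                    # only the true-class term survives

print(full_cross_entropy)  # ~0.288 nats
print(collapsed)           # identical value
```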

One might wonder – if the cross entropy loss for classification tasks reduces to a single output node calculation, how does the neural network learn to both increase the softmax value that corresponds to the true index and decrease the values of all the other nodes? It does this through the cross interaction of nodes via the weights, but also through the nature of the softmax function itself – if a single index is encouraged to increase, all the other indices/output classes are encouraged to decrease, because the softmax outputs must sum to one.

In TensorFlow 2.0, the function to use to calculate the cross entropy loss is tf.keras.losses.CategoricalCrossentropy(), where the P values are one-hot encoded. If you’d prefer to leave your true classification values as integers which designate the true classes (rather than one-hot encoded vectors), you can instead use the tf.losses.SparseCategoricalCrossentropy() function. In PyTorch, the function to use is torch.nn.CrossEntropyLoss() – note, however, that this function performs a softmax transformation of its input before calculating the cross entropy, so you should supply only the “logits” (the raw, pre-activation output layer values) from your classifier network. The TensorFlow functions above, by default, expect a softmax activation to already have been applied to the output of the neural network.
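As a quick usage sketch (the target and output values are purely illustrative, and the frameworks’ default settings are assumed – probabilities for the TensorFlow loss, raw logits for the PyTorch loss):

```python
import tensorflow as tf
import torch

# TensorFlow / Keras: expects softmax probabilities by default
y_true = [[0.0, 0.0, 1.0]]     # one-hot target
y_pred = [[0.05, 0.20, 0.75]]  # softmax output
print(tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred))  # ~0.288

# PyTorch: expects raw logits and integer class indices
logits = torch.tensor([[0.5, 1.2, 2.5]])  # pre-softmax outputs
target = torch.tensor([2])                # true class index
print(torch.nn.CrossEntropyLoss()(logits, target))
```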

That concludes this tutorial on the important concepts of entropy, cross entropy and KL divergence. I hope it will help you deepen your understanding of these commonly used functions, and thereby deepen your understanding of machine learning and neural networks.

