Keras Internals: Debugging a Softmax Behavior Inconsistency
Exploring the Keras Softmax implementation details and fixing a subtle behavior inconsistency with masking.
As an engineer on the Keras team, I often encounter interesting technical details. Since Keras is an open-source project, I have the opportunity to share some of these insights with the community.
For the uninitiated, Keras is a multi-backend deep learning library. It provides a powerful and easy-to-use API on top of a variety of backends, such as JAX, PyTorch, and TensorFlow.
Recently, a behavioral discrepancy was identified between the Keras Softmax implementation and its counterpart in JAX, particularly when using input masking. Ensuring consistent behavior across different computational backends (such as JAX, TensorFlow, and PyTorch) is a fundamental requirement for a multi-backend framework like Keras.
As part of the investigation, I ended up peeking into the internals of how Keras implements Softmax, and it was interesting to see the design choices made to accommodate a multi-backend API.
Softmax Primer
Let’s start with a small primer on what the softmax function does and how the mask argument changes its behavior. Then we can discuss the behavioral inconsistency that the Softmax implementation was exhibiting.
Softmax is a function commonly used in the final layer of neural networks for classification tasks.
It takes a vector of arbitrary real-valued scores (often called logits) and transforms them into a probability distribution.
Essentially, it converts these scores into probabilities that sum up to 1, where each probability corresponds to the likelihood of a particular class.
Higher input scores result in higher output probabilities.
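For instance, here is a quick NumPy illustration (the logit values are made up purely for demonstration):

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
probs = np.exp(logits) / np.exp(logits).sum()

print(probs)        # ~[0.659, 0.242, 0.099]: larger logits get larger probabilities
print(probs.sum())  # 1.0
```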
Mask
The mask argument adds a layer of control. It allows you to specify which elements of the input vector should be included in the softmax calculation. When an element is masked (typically indicated by False in the mask array), it's effectively excluded from the probability calculation.
In practice, this is often achieved by assigning a very large negative value to the masked elements before applying the exponential function. This ensures their contribution to the final probability distribution is zero (or numerically very close to zero), and the remaining probabilities are normalized across only the unmasked elements.
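Continuing the tiny example above, masking the last logit might look like this (a sketch of the large-negative-value approach, again with made-up values):

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
mask = np.array([True, True, False])

# Replace masked logits with a very large negative value, then apply softmax as usual.
masked_logits = np.where(mask, logits, -1e9)
probs = np.exp(masked_logits - masked_logits.max())
probs /= probs.sum()

print(probs)  # ~[0.731, 0.269, 0.0]: the masked position gets (numerically) zero
              # probability, and the remaining probabilities renormalize to sum to 1
```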
Behavior Divergence with JAX
GitHub issue: Softmax layer diverges from jax.nn.softmax · Issue #21123 · keras-team/keras
What happens when someone masks an entire axis while calculating the softmax? Any guesses?
Let’s look at the code.
- We create a random tensor of shape (1, 4, 4), sampled from a uniform distribution.
- We define a boolean mask, with one row of the tensor completely masked out (all positions set to False).
- We calculate the softmax activation for the tensor along the default axis (-1), passing the mask as the where parameter of the jax.nn.softmax function.

As you might guess, we get a softmax projection of the original tensor, where each value along the softmax axis represents a probability. The probabilities across an axis sum up to 1.
Also notice that the masked inputs are all zeroed out, as one might expect.
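Here is roughly what that setup looks like in code; the random seed and the choice of which row to mask are mine for illustration, so the exact numbers are not meaningful, only the pattern is:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (1, 4, 4))

# Boolean mask with the last row of the last axis fully masked out.
mask = jnp.ones((1, 4, 4), dtype=bool).at[0, 3, :].set(False)

out = jax.nn.softmax(x, axis=-1, where=mask)
print(out)
# Rows 0-2 each sum to 1; the fully masked row comes back as all zeros.
```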
Let’s try to recreate this operation with the same input tensor and mask, using Keras with the JAX backend.
- We start by initializing a Softmax layer (again on the default axis, -1).
- We pass the random tensor to the layer, along with the mask (sketched below).
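A minimal sketch of those two steps, reusing x and mask from the JAX snippet above:

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import keras

layer = keras.layers.Softmax(axis=-1)  # -1 is also the default axis
out = layer(x, mask=mask)              # x and mask from the JAX snippet above
print(out)
# Rows 0-2 match jax.nn.softmax; the fully masked row comes back as
# [0.25, 0.25, 0.25, 0.25] instead of zeros.
```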
The outputs look similar, except for the last row. For some reason, instead of being zeroed out, all the elements are set to 0.25!
Why do you think that is? Hint: it has something to do with how masking is implemented. If you look closely, returning 0.25 across all four masked positions is the same as returning a uniform probability distribution over the entire axis (each position gets the same probability).
There is no hard guidance around what the output of a softmax layer should look like when using masking, and it’s interesting to see this behavior discrepancy between JAX and Keras when masking an entire axis.
It’s not possible to contrast these outputs with TensorFlow or PyTorch, since their softmax implementations do not support masking. A natural choice at this point is to treat the JAX softmax behavior as canonical, and ensure Keras follows it.
How Keras implements Softmax
Let’s understand the backend-agnostic softmax implementation that Keras provides. Each backend may differ slightly in the APIs it offers, which leads to some interesting design choices.
Let’s analyze the call() method of keras.layers.Softmax, since that’s where the interesting stuff happens. I’ve added comments to the function to explain all the details.
Note: Keras uses a backend object whenever it wants to delegate computation to a backend-specific API. The backend object encapsulates the user-selected backend (JAX, TensorFlow, or PyTorch). This allows all implementations to rely on a simple backend-based abstraction, without having to worry about how each computational backend handles a specific task.
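Sketched out, the method looks roughly like this; the import paths and the _large_negative_number helper are paraphrased approximations based on the behavior discussed below, not the literal Keras source:

```python
from keras.src import activations, backend  # approximate import paths


def _large_negative_number(dtype):
    # Approximate helper: a "very large negative" constant, chosen so that
    # exp() of a masked logit is (effectively) zero.
    return -1e9


def call(self, inputs, mask=None):
    if mask is not None:
        # Push masked positions towards -inf so they contribute ~0 probability.
        adder = (
            1.0 - backend.cast(mask, inputs.dtype)
        ) * _large_negative_number(inputs.dtype)
        inputs += adder
    if isinstance(self.axis, (tuple, list)):
        if len(self.axis) > 1:
            # Multi-axis softmax: not supported natively by every backend,
            # so Keras computes it via the log-sum-exp trick.
            return backend.numpy.exp(
                inputs
                - backend.math.logsumexp(inputs, axis=self.axis, keepdims=True)
            )
        return activations.softmax(inputs, axis=self.axis[0])
    # Single-axis softmax: delegate to the backend implementation.
    return activations.softmax(inputs, axis=self.axis)
```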
Adding Masks
The intuition behind how softmax masking is implemented is simple. We add a very large negative number to all the masked positions, effectively eliminating their contribution to the softmax computation (the probability of choosing them becomes so low that it’s practically zero). Numerically, they no longer contribute to the probability distribution.
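In NumPy terms, the trick looks something like this (a conceptual sketch, not the Keras code itself):

```python
import numpy as np

inputs = np.array([2.0, 1.0, 0.1, 3.0])
mask = np.array([True, True, False, True])

# Masked positions receive a huge negative offset; unmasked positions are untouched.
adder = (1.0 - mask.astype(inputs.dtype)) * -1e9
shifted = inputs + adder

probs = np.exp(shifted - shifted.max())
probs /= probs.sum()
print(probs)  # ~[0.245, 0.090, 0.0, 0.665]: the masked position drops out entirely
```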
Calculating Softmax over multiple Axes
An interesting detail is that while JAX supports multi-axis softmax, TensorFlow and PyTorch don’t. Therefore, Keras provides its own implementation for this case.
This is another cool reason to use Keras: you get additional bells and whistles on top of functionality you already understand well, without having to implement them yourself.
We use something called the log-sum-exp trick to calculate the softmax projection over multiple axes.
I won’t be going very deep into the mathematical derivation of why this trick works (even though it’s quite simple), because, to be fair, I’m not (yet) the best math guy out there to confidently write derivations in my Substack articles. However, I’ll nudge you towards an excellent article (< 4 min read) which gives you a solid mathematical grounding for this. You can find it here.
All the backends already support the logsumexp operation natively, which means this computation will be hardware-accelerated and well-optimized.
Implementing the logsumexp trick
1. backend.math.logsumexp(inputs, axis=self.axis, keepdims=True)

   This calculates log(sum(exp(inputs))) across the specified axis (or axes). Crucially, the logsumexp function itself is implemented internally in a numerically stable way, often by subtracting the maximum value before exponentiating: logsumexp(x) = max(x) + log(sum(exp(x - max(x)))). This prevents the intermediate exp calculations from overflowing. (The article I linked above derives this too.)

   - axis=self.axis tells the function which dimension(s) to reduce over.
   - keepdims=True: when reducing over an axis (or axes), that dimension usually disappears. keepdims=True retains the reduced axes with a size of 1, which makes the shape of the logsumexp result compatible for broadcasting when it is subtracted from the original inputs tensor.

2. inputs - backend.math.logsumexp(...)

   This performs the subtraction inputs - log(sum(exp(inputs))). Because logsumexp’s result was computed stably and has compatible dimensions (thanks to keepdims=True), this subtraction is well-defined. The values here are essentially the logarithms of the final softmax probabilities.

3. backend.numpy.exp(...)

   This takes the exponent of the result from the previous step, converting the stable log-probabilities back into the actual softmax probabilities: exp(log(softmax(inputs))) = softmax(inputs).
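Putting the three steps together, here is a small end-to-end check using JAX directly (rather than the Keras backend object), just to illustrate the identity softmax(x) = exp(x - logsumexp(x)):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)

# Softmax over axes (1, 2) in one shot: exp(x - logsumexp(x)).
probs = jnp.exp(x - logsumexp(x, axis=(1, 2), keepdims=True))

print(probs.shape)             # (2, 3, 4)
print(probs.sum(axis=(1, 2)))  # ~[1. 1.]: each (3, 4) slice is a valid distribution
```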
In the case of a single-axis softmax computation, we simply call the activations.softmax(…) method, which internally calls keras.ops.softmax(x, axis=axis).
Let’s understand how that’s implemented. As earlier, I’ve added comments to all the lines.

- After some initial validation, we check whether we’re operating over symbolic tensors (when no actual inputs have been provided) and handle those accordingly.
- We check whether we are computing softmax over a single axis or over multiple axes. We won’t hit the multi-axis branch in our case, since we already handled that at the Softmax layer level (as discussed above). It’s still interesting to see how this is implemented: the gist is that we temporarily rearrange and flatten the specific axes we’re interested in into a single dimension at the end, apply the standard single-axis softmax to this flattened dimension, and then reverse the rearrangement to restore the original tensor shape (see the sketch after this list).
- The last else condition is what we’ll usually be dealing with in the case of keras.layers.Softmax. It delegates the single-axis softmax calculation to the respective backend in use.
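Here is a standalone NumPy sketch of that rearrange-flatten-restore idea; it mirrors the approach described above but is not the actual keras.ops.softmax source:

```python
import numpy as np


def softmax_multi_axis(x, axes):
    """Softmax over several axes by flattening them into one trailing axis."""
    axes = [a % x.ndim for a in axes]
    keep = [a for a in range(x.ndim) if a not in axes]
    perm = keep + axes
    # 1. Move the softmax axes to the end.
    xt = np.transpose(x, perm)
    # 2. Flatten them into a single trailing dimension.
    xf = xt.reshape([x.shape[a] for a in keep] + [-1])
    # 3. Standard single-axis softmax over that trailing dimension.
    e = np.exp(xf - xf.max(axis=-1, keepdims=True))
    sm = e / e.sum(axis=-1, keepdims=True)
    # 4. Undo the reshape and the transpose.
    return np.transpose(sm.reshape(xt.shape), np.argsort(perm))


x = np.random.rand(2, 3, 4)
out = softmax_multi_axis(x, axes=[1, 2])
print(out.sum(axis=(1, 2)))  # ~[1. 1.]: sums to 1 over the combined axes
```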
The Masking inconsistency
That was a lot about how Keras implements Softmax. Now let’s try to understand why Keras outputs a uniform probability distribution when we mask an entire axis.
We’ll need to go back to how masking is implemented. We add a large negative number to all masked positions to drop their contribution to the probability distribution to practically zero compared to all other positions.
But what happens when we do this at all the positions across a softmax axis?
Driving all logits towards negative infinity makes them numerically indistinguishable from each other. When softmax exponentiates these similar, large negative inputs, the outputs become tiny, near-identical positive numbers. Normalizing these values inevitably leads to an equal probability assigned to each position – a uniform distribution.
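A quick NumPy illustration of that effect:

```python
import numpy as np

row = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32)
masked_row = row + np.float32(-1e9)  # every position along the axis is masked

print(np.unique(masked_row))  # a single value: the logits are now indistinguishable

probs = np.exp(masked_row - masked_row.max())
probs /= probs.sum()
print(probs)  # [0.25 0.25 0.25 0.25]: a uniform distribution, not zeros
```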
The Fix
It’s quite easy to fix this issue once you know why it happens. All we need to do is multiply the final outputs, whatever they are, by the original boolean mask. This ensures that the masked outputs are always set to zero.
The following figure shows you how we need to modify the existing implementation.
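In code, the change amounts to something like the following (a paraphrased sketch of how the multiplication slots into the call() method shown earlier, not the literal patch):

```python
def call(self, inputs, mask=None):
    # Same masking and softmax logic as the sketch above...
    if mask is not None:
        adder = (
            1.0 - backend.cast(mask, inputs.dtype)
        ) * _large_negative_number(inputs.dtype)
        inputs += adder
    if isinstance(self.axis, (tuple, list)) and len(self.axis) > 1:
        outputs = backend.numpy.exp(
            inputs - backend.math.logsumexp(inputs, axis=self.axis, keepdims=True)
        )
    else:
        axis = self.axis[0] if isinstance(self.axis, (tuple, list)) else self.axis
        outputs = activations.softmax(inputs, axis=axis)
    # ...plus the fix: multiply by the boolean mask so masked positions are
    # exactly zero, even when an entire axis is masked out (matching JAX).
    if mask is not None:
        outputs = outputs * backend.cast(mask, outputs.dtype)
    return outputs
```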
Conclusion
We went over how Keras implements the softmax layer, and how sometimes computational tricks can lead to unexpected results! It’s often not hard to fix these issues if you truly understand how things are implemented under the hood.
See you in the next one!