Beyond ‘training’: Dynamic Context Propagation in Keras!
Understanding 'training' and call-context argument propagation in Keras 3
If you've built custom layers or explored Keras internals, you've probably come across the `training` argument. It's a boolean flag that magically tells layers like `Dropout` or `BatchNormalization` whether to behave in their training mode (e.g., applying dropout, updating moving averages) or inference mode (e.g., disabling dropout, using moving averages seen during training).
Consider this simple illustration using nested custom layers within a Functional model.
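Here is a minimal sketch of what such a setup might look like; the class names `OuterLayerCallingInner` and `InnerLayerWithTrainingLogic` match the discussion below, while their bodies are purely illustrative:

```python
import keras
from keras import layers, ops


class InnerLayerWithTrainingLogic(layers.Layer):
    """A layer whose behaviour depends on the resolved `training` flag."""

    def call(self, inputs, training=None):
        if training:
            # Stand-in for training-only behaviour (dropout, noise, etc.).
            return inputs * 0.5
        return inputs


class OuterLayerCallingInner(layers.Layer):
    """Wraps the inner layer but never mentions `training` itself."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inner = InnerLayerWithTrainingLogic()

    def call(self, inputs):
        # No `training` argument here: Keras propagates it via the call context.
        return self.inner(inputs)


inputs = keras.Input(shape=(4,))
outputs = OuterLayerCallingInner()(inputs)
model = keras.Model(inputs, outputs)

x = ops.ones((2, 4))
print(model(x, training=True))   # inner layer takes its training branch
print(model(x, training=False))  # inner layer takes its inference branch
```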
When the Keras Functional model is called with `training=True`, this `training` argument is automatically passed down through `OuterLayerCallingInner` to `InnerLayerWithTrainingLogic`, causing it to execute its training-specific logic. Conversely, `training=False` ensures the inference path is taken.
But have you ever wondered how Keras seamlessly propagates this argument, especially through complex, nested model architectures? And more importantly, have you ever wished you could have a similar mechanism for your own custom arguments?
In this article, I'll explain how the `training` argument works, and how you can use the new call-context arguments to replicate that behavior for your own custom arguments!
‘training’ Magic!
Before we jump into custom arguments, let's appreciate the elegance of how `training` is handled. When you call a Keras layer or model, Keras doesn't just blindly pass the `training` argument you might have provided to every layer's `call` method. It intelligently resolves its value based on a clear priority (a small sketch after the list illustrates this):
1. User-defined value: If you explicitly pass `training=True` or `training=False` when calling the layer or model, that takes precedence.
2. Inherited context: If no explicit value is given, Keras looks at the context from an outer layer or model call. This is how, for example, `model.fit()` implicitly sets `training=True` for all inner layers and `model.predict()` sets it to `False`.
3. Default value: If neither of the above is present, the default value defined in the layer's `call` method signature is used.
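As a rough illustration of this priority order, consider a hypothetical layer that deliberately freezes one of its internal `Dropout` sublayers:

```python
import keras
from keras import layers


class MostlyFrozenDropout(layers.Layer):
    """Hypothetical layer illustrating how `training` gets resolved."""

    def __init__(self, rate=0.5, **kwargs):
        super().__init__(**kwargs)
        self.frozen_dropout = layers.Dropout(rate)
        self.normal_dropout = layers.Dropout(rate)

    # Priority 3: if no explicit value and no outer context exist,
    # `training` falls back to this signature default.
    def call(self, inputs, training=False):
        # Priority 1: an explicitly passed value wins, so this Dropout stays
        # in inference mode even inside model.fit().
        x = self.frozen_dropout(inputs, training=False)
        # Priority 2: nothing passed here, so this Dropout inherits the
        # resolved `training` value from the surrounding call context.
        return self.normal_dropout(x)


layer = MostlyFrozenDropout()
out = layer(keras.ops.ones((2, 4)), training=True)  # frozen_dropout still runs in inference mode
```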
This resolved value is then stored in an internal, call-stack-aware context object (`CallContext`), which allows it to be automatically propagated to any nested layers without you needing to manually thread the `training` argument through every single call signature. If a layer's `call` method actually declares `training` in its signature, Keras populates it with this resolved value.
We’ll discuss how this context handling works later in the article.
What if ‘training’ is not enough?
While the `training` argument is fundamental, advanced use cases sometimes require other contextual arguments to be propagated with similar ease.
Imagine wanting to adjust a `noise_injection_scale` on the fly during robustness experiments without recompiling your model. You might be dynamically choosing a `dropout_variant` within a single versatile layer, or even guiding a layer to adopt a specific `knowledge_distillation_role` such as `'teacher_logits'` or `'student_match'` in a complex training pipeline.
Manually threading these kinds of situational flags through every layer in a deep stack, especially when dealing with nested models, quickly becomes unwieldy. It burdens users with extra boilerplate, the possibility of user error (forgetting to pass an argument somewhere), and additional cognitive load.
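To make the pain concrete, here is a sketch of what that manual threading looks like; the layer names are hypothetical, and `noise_injection_scale` is borrowed from the examples above:

```python
import keras
from keras import layers, ops


class NoisyDense(layers.Layer):
    """The only layer that actually uses the flag."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = layers.Dense(units)

    def call(self, inputs, training=None, noise_injection_scale=0.0):
        x = self.dense(inputs)
        if training and noise_injection_scale:
            x = x + noise_injection_scale * keras.random.normal(ops.shape(x))
        return x


class MidBlock(layers.Layer):
    """Never uses the flag itself, yet still has to accept and forward it
    (unlike `training`, which propagates on its own). Every layer between the
    top-level call and NoisyDense needs the same boilerplate."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.noisy = NoisyDense(8)

    def call(self, inputs, noise_injection_scale=0.0):
        return self.noisy(inputs, noise_injection_scale=noise_injection_scale)


block = MidBlock()
out = block(ops.ones((2, 4)), training=True, noise_injection_scale=0.1)
```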
Such use cases motivated the idea of a generalized call-context argument propagation mechanism, which lets developers easily declare `training`-like custom arguments that magically propagate through the model.
Call-Context Arguments
Keras now provides a clean and powerful way for layer developers to define their own call-context arguments. Model authors can leverage a special instance method on their custom layers, `_register_call_context_arguments`, to declare and register any custom arguments they wish to be propagated across the model.
Here’s how it works:
1. Declare your context arguments: In your custom layer's `__init__` method (the layer that uses the context argument), call `self._register_call_context_arguments(...)` with the argument names. This tells Keras this specific layer instance is interested in potentially receiving these arguments from the context.
2. Define them in the `call` signature: Your layer's `call` method should include these custom arguments in its signature, typically with a default value.
3. Register the context arguments with the caller (a model or top-level layer): On the layer or model instance that will initiate the propagation (i.e., where you will pass the argument in its call), also call `_register_call_context_arguments(...)`. This tells Keras that calls to this specific object might include these arguments, and that they should be placed into the context for potential downstream consumption.
4. Keras handles the rest: The base `Layer.__call__` method will:
   - Remember the registered call-context arguments.
   - Resolve their values using the same priority resolution mechanism as `training`: the value explicitly passed by the user in the current layer call, then the value inherited from the context of an outer layer, then the default value from the `call` method signature of the current layer.
   - Store the resolved value in the current call stack's context for propagation to nested layers.
   - Populate the argument in your layer's `call` method if it's defined in the signature and the resolved value is not `None`.
Note: You need two registrations for propagation to work:

1. The consuming layer registers its interest in `__init__`.
2. The top-level caller (where the value is initially provided) registers the argument names to enable the propagation mechanism for its call stack.

Intermediate layers that merely pass the context through do not need to register the arguments.
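Putting these steps together, a minimal skeleton might look like the following; the `ContextAwareLayer` name and the `precision_mode` argument are made up for illustration, and the registration calls follow the API as described above:

```python
import keras
from keras import layers


class ContextAwareLayer(layers.Layer):
    """A layer that consumes a custom call-context argument."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Step 1: declare this layer's interest in the custom argument.
        self._register_call_context_arguments("precision_mode")

    # Step 2: include the argument in the call signature, with a default.
    def call(self, inputs, precision_mode="high"):
        # The layer can now branch on precision_mode (behaviour elided here).
        return inputs


inputs = keras.Input(shape=(4,))
outputs = ContextAwareLayer()(inputs)
model = keras.Model(inputs, outputs)

# Step 3: register the argument on the object you will actually call, so that
# passing it to `model(...)` places it into the call context.
model._register_call_context_arguments("precision_mode")

# Step 4: Keras resolves the value and propagates it down to ContextAwareLayer.
out = model(keras.ops.ones((2, 4)), precision_mode="low")
```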
A Motivating Example
Dynamic Dropout Intensity
Let’s consider a scenario where you want to dynamically adjust the “intensity” of dropout for specific layers deep within your model during different phases of training or for experimentation, without altering their initial configuration.
We'll create:

- `ConfigurableDropoutLayer`: A dropout layer that can change its behaviour based on a `dropout_intensity_mode` passed to its call. This is the layer which consumes the context argument.
- `PassThroughBlock`: An intermediate layer which uses the `ConfigurableDropoutLayer` but is itself oblivious to `dropout_intensity_mode`. It doesn't declare or use it.
The magic is that `dropout_intensity_mode` will flow through `PassThroughBlock` to `ConfigurableDropoutLayer` if it is set at a higher level.
Argument Propagation
- Because `ConfigurableDropoutLayer` registered `dropout_intensity_mode` in its `__init__`, its `call` knows to look for it in the context.
- Because the model registered the same argument via `model._register_call_context_arguments(...)`, the value provided in `model(...)` is actually placed into the context to begin with.
- `PassThroughBlock` is just an intermediary. It contains `ConfigurableDropoutLayer` but doesn't concern itself with its internal workings, so it doesn't directly use `dropout_intensity_mode` for its own logic.
When we invoke `model(data, training=True, dropout_intensity_mode="aggressive")`:
1. The `dropout_intensity_mode="aggressive"` value is introduced into the context at the beginning of this top-level call.
2. When `PassThroughBlock` is called, its `__call__` method (from the base `keras.Layer`) sees this context. Since `PassThroughBlock` doesn't register `dropout_intensity_mode` for itself, it doesn't consume or alter it. The context simply remains active.
3. When `PassThroughBlock` then calls its internal `self.contextual_dropout` (an instance of `ConfigurableDropoutLayer`), that layer's `__call__` method does check for `dropout_intensity_mode` in the context (because it declared it) and uses the value (`"aggressive"`) found there.
The context flows effortlessly through intermediate, unaware layers to the specific layers designed to act upon it. Only layers that produce or consume a custom call-context argument need to be aware of it.
This makes the feature very low-friction for building complex, adaptable models.
Under the Hood
1. When a layer is instantiated, it registers the call-context arguments provided in the `_register_call_context_arguments()` call and merges them with the built-in `{"training"}`.
2. Internally, Keras manages this using a `CallContext` object, implemented via thread-local global state. This ensures that context values set in an outer call are available to nested calls within the same forward pass but isolated between different top-level calls.
3. During `Layer.__call__`, Keras iterates through this combined set of context-aware arguments. For each one, it resolves the value (user input > context > signature default) using the `CallContext` object and then updates the `CallContext` for downstream layers.
4. If the current layer's `call` signature actually includes the argument and its resolved value isn't `None`, it's injected into the `kwargs` for the actual `call` method invocation.
5. The `Functional` and `Sequential` call methods also correctly pass through `**kwargs`, ensuring these custom call arguments reach the underlying layers.
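As a mental model only (this is emphatically not the actual Keras source), the resolution step behaves roughly like this:

```python
import threading

# A toy stand-in for Keras's CallContext: thread-local state so that values set
# in an outer call are visible to nested calls, but isolated across threads.
# This is a conceptual sketch, not the real implementation.
_state = threading.local()


def _context() -> dict:
    if not hasattr(_state, "values"):
        _state.values = {}
    return _state.values


def resolve_call_context_arg(name, passed_kwargs, signature_default):
    """Resolve a context-aware argument: user value > outer context > default."""
    ctx = _context()
    if name in passed_kwargs:
        value = passed_kwargs[name]      # 1. explicitly passed by the caller
    elif name in ctx:
        value = ctx[name]                # 2. inherited from an outer call
    else:
        value = signature_default        # 3. the call signature's default
    ctx[name] = value                    # make it visible to nested calls
    return value
```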
Final Thoughts
We went over a simple example to demonstrate the usefulness of call-context arguments. It's easy to come up with more advanced use cases involving profiling, dynamic layer activations, debugging utilities, and more.
It's interesting to see how much complexity Keras magically abstracts away from the user, so that model authors can focus on building models instead of grunt work like threading contextual arguments through layer hierarchies.
In case you’re interested in checking out the implementation of this feature in code, take a look at the following PRs:
keras-team/keras#21204 (initial support)
keras-team/keras#21222 (API improvements)