Beyond ‘training’: Dynamic Context Propagation in Keras!
Understanding ‘training’ and call-context argument propagation in Keras 3
If you've built custom layers or explored Keras internals, you've probably come across the training argument. It's a boolean flag that magically tells layers like Dropout or BatchNormalization whether to behave in their training mode (e.g., applying dropout, updating moving averages) or inference mode (e.g., disabling dropout, using moving averages seen during training).
Consider this simple illustration using nested custom layers within a Functional model:
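```python
import keras
from keras import layers
import numpy as np


class InnerLayerWithTrainingLogic(layers.Layer):
    def call(self, inputs, training=None):
        if training:
            return inputs + 1  # Behavior specific to training
        else:
            return inputs  # Behavior specific to inference


class OuterLayerCallingInner(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inner_layer = InnerLayerWithTrainingLogic()

    # OuterLayer's call doesn't even need to declare 'training'.
    # Keras handles propagation if 'training' is passed to OuterLayer.
    def call(self, inputs):
        return self.inner_layer(inputs)


input_tensor = keras.Input(shape=(2,))
output_tensor = OuterLayerCallingInner()(input_tensor)
model = keras.Model(inputs=input_tensor, outputs=output_tensor)

sample_data = np.array([[10.0, 20.0]])

# Call the model with training=True
# InnerLayerWithTrainingLogic should add 1
# Expected: [[11.0, 21.0]]
output_train_mode = model(sample_data, training=True)

# Call the model with training=False (or omit, as in model.predict())
# InnerLayerWithTrainingLogic should not add 1
# Expected: [[10.0, 20.0]]
output_inference_mode = model(sample_data, training=False)
```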
When the Keras Functional model is called with training=True, this training argument is automatically passed down through OuterLayerCallingInner to InnerLayerWithTrainingLogic, causing it to execute its training-specific logic. Conversely, training=False ensures the inference path is taken.
But have you ever wondered how Keras seamlessly propagates this argument, especially through complex, nested model architectures? And more importantly, have you ever wished you could have a similar mechanism for your own custom arguments?
In this article, I’ll be explaining how the training argument works, and how you can use the new call-context arguments to replicate that behavior for your own custom arguments!
‘training’ Magic!
Before we jump into custom arguments, let's appreciate the elegance of how training is handled. When you call a Keras layer or model, Keras doesn't just blindly pass the training argument you might have provided to every layer’s call method. It intelligently resolves its value based on a clear priority:
1. User-defined value: If you explicitly pass `training=True` or `training=False` when calling the layer or model, that takes precedence.
2. Inherited context: If no explicit value is given, Keras looks at the context from an outer layer or model call. This is how, for example, `model.fit()` implicitly sets `training=True` for all inner layers and `model.predict()` sets it to `False`.
3. Default value: If neither of the above is present, the default value defined in the layer's `call` method signature is used.
This resolved value is then stored in an internal, call-stack-aware context object (CallContext), which allows it to be automatically propagated to any nested layers without you needing to manually thread the training argument through every single call signature. If a layer's call method actually declares training in its signature, Keras populates it with this resolved value.
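For example, an explicitly passed value always wins over the inherited context. The `AlwaysInferenceDropout` layer below is a hypothetical illustration of that first priority rule:

```python
import keras
from keras import layers


class AlwaysInferenceDropout(layers.Layer):
    """Hypothetical layer that forces its inner Dropout into inference mode."""

    def __init__(self, rate=0.5, **kwargs):
        super().__init__(**kwargs)
        self.dropout = layers.Dropout(rate)

    def call(self, inputs):
        # The explicit training=False takes priority over any inherited
        # context, so dropout stays disabled even when the surrounding
        # model is called with training=True.
        return self.dropout(inputs, training=False)
```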
We’ll discuss how this context handling works later in the article.
What if ‘training’ is not enough?
While the training argument is fundamental, advanced use cases sometimes require other contextual arguments to be propagated with similar ease.
Imagine wanting to adjust a noise_injection_scale on the fly during robustness experiments without recompiling your model. You might be dynamically choosing a dropout_variant within a single versatile layer, or even guiding a layer to adopt a specific knowledge_distillation_role such as 'teacher_logits' or 'student_match' in a complex training pipeline.
Manually threading these kinds of situational flags through every layer in a deep stack, especially when dealing with nested models, can quickly become unwieldy. It burdens users with additional boilerplate code, the possibility of user error (forgetting to pass an argument somewhere), and extra cognitive load.
Such use cases motivated the idea of a generalized call-context argument propagation mechanism, which allows developers to easily declare ‘training’-like custom arguments which magically propagate through the model.
Call-Context Arguments
Keras now provides a clean and powerful way for layer developers to define their own call-context arguments. Model authors can leverage a special instance method in their custom layers called _register_call_context_arguments to declare and register any custom arguments which they wish to be propagated across the model.
Here’s how it works:
1. Declare your context arguments: In your custom layer's `__init__` method (the layer that uses the context argument), call `self._register_call_context_arguments(...)` with the argument names. This tells Keras this specific layer instance is interested in potentially receiving these arguments from the context.
2. Define them in the `call` signature: Your layer's `call` method should include these custom arguments in its signature, typically with a default value.
3. Register the context arguments with the caller (a model or top-level layer): On the layer or model instance that will initiate the propagation (i.e., where you will pass the argument in its call), also call `_register_call_context_arguments(...)`. This tells Keras that calls to this specific object might include these arguments, and they should be placed into the context for potential downstream consumption.
4. Keras handles the rest: The base `Layer.__call__` method will:
   - Remember the registered call-context arguments.
   - Resolve their values using the same priority resolution mechanism as `training`:
     1. Value explicitly passed by the user in the current layer call.
     2. Value inherited from the context of an outer layer.
     3. Default value from the `call` method signature of the current layer.
   - Store the resolved value in the current call stack's context for propagation to nested layers.
   - Populate the argument in your layer's `call` method if it's defined in the signature and the resolved value is not `None`.

Note: You need two registrations for propagation to work:

1. The consuming layer registers its interest in `__init__`.
2. The top-level caller (where the value is initially provided) registers the argument names to enable the propagation mechanism for its call stack. Intermediate layers that merely pass the context through do not need to register the arguments.
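Here is a rough skeleton of that pattern, using the `noise_injection_scale` idea from earlier. The `NoisyDense` layer and the way it consumes the argument are illustrative assumptions, and the exact calling convention of `_register_call_context_arguments` may differ slightly from what's shown:

```python
import keras
from keras import layers


class NoisyDense(layers.Layer):
    """Hypothetical consumer of a custom call-context argument."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = layers.Dense(units)
        # 1. The consuming layer declares its interest in the argument.
        self._register_call_context_arguments("noise_injection_scale")

    # 2. The argument also appears in the call signature, with a default.
    def call(self, inputs, training=None, noise_injection_scale=0.0):
        outputs = self.dense(inputs)
        if training and noise_injection_scale:
            noise = keras.random.normal(keras.ops.shape(outputs))
            outputs = outputs + noise * noise_injection_scale
        return outputs


inputs = keras.Input(shape=(8,))
outputs = NoisyDense(4)(inputs)
model = keras.Model(inputs, outputs)

# 3. The top-level caller registers the argument too, so that a value
#    passed to model(...) is placed into the context and propagated.
model._register_call_context_arguments("noise_injection_scale")
```

After that, a call like `model(x, training=True, noise_injection_scale=0.1)` would reach `NoisyDense` even if it sat several layers deep.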
A Motivating Example
Dynamic Dropout Intensity
Let’s consider a scenario where you want to dynamically adjust the “intensity” of dropout for specific layers deep within your model during different phases of training or for experimentation, without altering their initial configuration.
We’ll create:
- `ConfigurableDropoutLayer`: A dropout layer that can change its behaviour based on a `dropout_intensity_mode` passed to its call. This is the layer which consumes the context argument.
- `PassThroughBlock`: An intermediate layer which uses the `ConfigurableDropoutLayer` but is itself oblivious to `dropout_intensity_mode`. It doesn't declare or use it.

The magic is that `dropout_intensity_mode` will flow through `PassThroughBlock` to `ConfigurableDropoutLayer` if set at a higher level.
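A minimal sketch of what these two layers might look like, assuming a simple mapping from intensity mode to dropout rate (the mapping, the use of `keras.random.dropout`, and the layer internals are illustrative assumptions; the registration calls follow the description above):

```python
import keras
from keras import layers


class ConfigurableDropoutLayer(layers.Layer):
    """Consumes the custom `dropout_intensity_mode` call-context argument."""

    # Assumed mapping from intensity mode to dropout rate (illustrative only).
    RATES = {"light": 0.1, "normal": 0.3, "aggressive": 0.6}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Declare interest in the custom context argument.
        self._register_call_context_arguments("dropout_intensity_mode")

    def call(self, inputs, training=None, dropout_intensity_mode="normal"):
        if not training:
            return inputs
        rate = self.RATES.get(dropout_intensity_mode, 0.3)
        return keras.random.dropout(inputs, rate=rate)


class PassThroughBlock(layers.Layer):
    """Intermediate layer: knows nothing about `dropout_intensity_mode`."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = layers.Dense(16, activation="relu")
        self.contextual_dropout = ConfigurableDropoutLayer()

    def call(self, inputs, training=None):
        x = self.dense(inputs)
        return self.contextual_dropout(x)


inputs = keras.Input(shape=(8,))
outputs = PassThroughBlock()(inputs)
model = keras.Model(inputs, outputs)

# The top-level caller also registers the argument, so values passed to
# model(...) enter the context for downstream consumption.
model._register_call_context_arguments("dropout_intensity_mode")
```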
Argument Propagation
- Because `ConfigurableDropoutLayer` registered `dropout_intensity_mode` in its `__init__`, its `call` knows to look for it in the context.
- Because the model registered the same argument via `model._register_call_context_arguments(...)`, the value provided in `model(...)` is actually placed into the context to begin with.
- `PassThroughBlock` is just an intermediary. It contains the `ConfigurableDropoutLayer` but doesn't concern itself with its internal workings, therefore it doesn't directly use `dropout_intensity_mode` for its own logic.
When we invoke `model(data, training=True, dropout_intensity_mode="aggressive")`:

- The `dropout_intensity_mode="aggressive"` is introduced into the context at the beginning of this top-level call.
- When `PassThroughBlock` is called, its `__call__` method (from the base `keras.Layer`) sees this context. Since `PassThroughBlock` doesn't register `dropout_intensity_mode` for itself, it doesn't consume or alter it. The context simply remains active.
- When `PassThroughBlock` then calls its internal `self.contextual_dropout` (an instance of `ConfigurableDropoutLayer`), that layer's `__call__` method does check for `dropout_intensity_mode` in the context (because it declared it) and uses the value (`"aggressive"`) found there.
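For example (illustrative only; the actual outputs depend on the random dropout mask):

```python
import numpy as np

data = np.random.rand(4, 8).astype("float32")

# "aggressive" flows through PassThroughBlock into ConfigurableDropoutLayer.
out_aggressive = model(data, training=True, dropout_intensity_mode="aggressive")

# Without an explicit value, the layer falls back to its call-signature default.
out_default = model(data, training=True)
```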
The context flows effortlessly through unaware intermediate layers to the specific layers designed to act upon it. Only layers that produce or consume a custom call-context argument need to be aware of it.
This makes the feature very low-friction for building complex, adaptable models.
Under the Hood
- When a layer is instantiated, it registers the call-context arguments provided in the `_register_call_context_arguments()` call and merges them with the built-in `{"training"}`.
- Internally, Keras manages this using a `CallContext` object, implemented via thread-local global state. This ensures that context values set in an outer call are available to nested calls within the same forward pass but isolated between different top-level calls.
- During `Layer.__call__`, Keras iterates through this combined set of context-aware arguments. For each one, it resolves the value (user input > context > signature default) using the `CallContext` object and then updates the `CallContext` for downstream layers.
- If the current layer's `call` signature actually includes the argument and its resolved value isn't `None`, it's injected into the kwargs for the actual `call` method invocation.
- The `Functional` and `Sequential` call methods also correctly pass through `**kwargs`, ensuring these custom call arguments reach the underlying layers.
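Conceptually, the resolution step looks something like this deliberately simplified sketch (not the actual Keras implementation; here the context is just a dictionary):

```python
def resolve_call_context_arg(name, user_kwargs, call_context, signature_default):
    """Illustrative only: user input > inherited context > signature default."""
    if user_kwargs.get(name) is not None:
        value = user_kwargs[name]          # 1. explicitly passed by the caller
    elif call_context.get(name) is not None:
        value = call_context[name]         # 2. inherited from an outer call
    else:
        value = signature_default          # 3. default from the call signature
    # Write the resolved value back into the context so that nested layers
    # invoked during this call can see it.
    call_context[name] = value
    return value
```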
Final Thoughts
We went over a simple example to demonstrate the usefulness of call-context arguments. It's easy to come up with more advanced use cases involving profiling, dynamic layer activations, debugging utilities, and more.
It's interesting to see how much complexity Keras magically abstracts away from the user, so that model authors can focus on building models instead of grunt work like threading contextual arguments through layer hierarchies.
In case you’re interested in checking out the implementation of this feature in code, take a look at the following PRs:
keras-team/keras#21204 (initial support)
keras-team/keras#21222 (API improvements)



